braindecode 0.8__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of braindecode might be problematic. Click here for more details.
- braindecode/__init__.py +1 -2
- braindecode/augmentation/__init__.py +50 -0
- braindecode/augmentation/base.py +222 -0
- braindecode/augmentation/functional.py +1096 -0
- braindecode/augmentation/transforms.py +1274 -0
- braindecode/classifier.py +26 -24
- braindecode/datasets/__init__.py +34 -0
- braindecode/datasets/base.py +840 -0
- braindecode/datasets/bbci.py +694 -0
- braindecode/datasets/bcicomp.py +194 -0
- braindecode/datasets/bids.py +245 -0
- braindecode/datasets/mne.py +172 -0
- braindecode/datasets/moabb.py +209 -0
- braindecode/datasets/nmt.py +311 -0
- braindecode/datasets/sleep_physio_challe_18.py +412 -0
- braindecode/datasets/sleep_physionet.py +125 -0
- braindecode/datasets/tuh.py +588 -0
- braindecode/datasets/xy.py +95 -0
- braindecode/datautil/__init__.py +49 -0
- braindecode/datautil/serialization.py +342 -0
- braindecode/datautil/util.py +41 -0
- braindecode/eegneuralnet.py +63 -47
- braindecode/functional/__init__.py +10 -0
- braindecode/functional/functions.py +251 -0
- braindecode/functional/initialization.py +47 -0
- braindecode/models/__init__.py +52 -0
- braindecode/models/atcnet.py +652 -0
- braindecode/models/attentionbasenet.py +550 -0
- braindecode/models/base.py +296 -0
- braindecode/models/biot.py +483 -0
- braindecode/models/contrawr.py +296 -0
- braindecode/models/ctnet.py +450 -0
- braindecode/models/deep4.py +322 -0
- braindecode/models/deepsleepnet.py +295 -0
- braindecode/models/eegconformer.py +372 -0
- braindecode/models/eeginception_erp.py +304 -0
- braindecode/models/eeginception_mi.py +371 -0
- braindecode/models/eegitnet.py +301 -0
- braindecode/models/eegminer.py +255 -0
- braindecode/models/eegnet.py +473 -0
- braindecode/models/eegnex.py +247 -0
- braindecode/models/eegresnet.py +362 -0
- braindecode/models/eegsimpleconv.py +199 -0
- braindecode/models/eegtcnet.py +335 -0
- braindecode/models/fbcnet.py +221 -0
- braindecode/models/fblightconvnet.py +313 -0
- braindecode/models/fbmsnet.py +325 -0
- braindecode/models/hybrid.py +126 -0
- braindecode/models/ifnet.py +441 -0
- braindecode/models/labram.py +1166 -0
- braindecode/models/msvtnet.py +375 -0
- braindecode/models/sccnet.py +182 -0
- braindecode/models/shallow_fbcsp.py +208 -0
- braindecode/models/signal_jepa.py +1012 -0
- braindecode/models/sinc_shallow.py +337 -0
- braindecode/models/sleep_stager_blanco_2020.py +167 -0
- braindecode/models/sleep_stager_chambon_2018.py +157 -0
- braindecode/models/sleep_stager_eldele_2021.py +536 -0
- braindecode/models/sparcnet.py +378 -0
- braindecode/models/summary.csv +41 -0
- braindecode/models/syncnet.py +232 -0
- braindecode/models/tcn.py +273 -0
- braindecode/models/tidnet.py +395 -0
- braindecode/models/tsinception.py +258 -0
- braindecode/models/usleep.py +340 -0
- braindecode/models/util.py +133 -0
- braindecode/modules/__init__.py +38 -0
- braindecode/modules/activation.py +60 -0
- braindecode/modules/attention.py +757 -0
- braindecode/modules/blocks.py +108 -0
- braindecode/modules/convolution.py +274 -0
- braindecode/modules/filter.py +632 -0
- braindecode/modules/layers.py +133 -0
- braindecode/modules/linear.py +50 -0
- braindecode/modules/parametrization.py +38 -0
- braindecode/modules/stats.py +77 -0
- braindecode/modules/util.py +77 -0
- braindecode/modules/wrapper.py +75 -0
- braindecode/preprocessing/__init__.py +37 -0
- braindecode/preprocessing/mne_preprocess.py +77 -0
- braindecode/preprocessing/preprocess.py +478 -0
- braindecode/preprocessing/windowers.py +1031 -0
- braindecode/regressor.py +23 -12
- braindecode/samplers/__init__.py +18 -0
- braindecode/samplers/base.py +401 -0
- braindecode/samplers/ssl.py +263 -0
- braindecode/training/__init__.py +23 -0
- braindecode/training/callbacks.py +23 -0
- braindecode/training/losses.py +105 -0
- braindecode/training/scoring.py +483 -0
- braindecode/util.py +55 -59
- braindecode/version.py +1 -1
- braindecode/visualization/__init__.py +8 -0
- braindecode/visualization/confusion_matrices.py +289 -0
- braindecode/visualization/gradients.py +57 -0
- {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info}/METADATA +39 -55
- braindecode-1.0.0.dist-info/RECORD +101 -0
- {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info}/WHEEL +1 -1
- {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info/licenses}/LICENSE.txt +1 -1
- braindecode-1.0.0.dist-info/licenses/NOTICE.txt +20 -0
- braindecode-0.8.dist-info/RECORD +0 -11
- {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,49 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Utilities for data manipulation.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from .serialization import (
|
|
6
|
+
_check_save_dir_empty,
|
|
7
|
+
load_concat_dataset,
|
|
8
|
+
save_concat_dataset,
|
|
9
|
+
)
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
# Names that used to be importable from ``braindecode.datautil``, mapped to
# (public package that now exports them, module to import them from relative
# to this package). Consulted by ``__getattr__`` below.
_MOVED_ATTRIBUTES = {
    "create_from_X_y": ("datasets", "..datasets.xy"),
    "create_from_mne_raw": ("datasets", "..datasets.mne"),
    "create_from_mne_epochs": ("datasets", "..datasets.mne"),
    "scale": ("preprocessing", "..preprocessing.preprocess"),
    "exponential_moving_demean": ("preprocessing", "..preprocessing.preprocess"),
    "exponential_moving_standardize": (
        "preprocessing",
        "..preprocessing.preprocess",
    ),
    "filterbank": ("preprocessing", "..preprocessing.preprocess"),
    "preprocess": ("preprocessing", "..preprocessing.preprocess"),
    "Preprocessor": ("preprocessing", "..preprocessing.preprocess"),
    "create_windows_from_events": ("preprocessing", "..preprocessing.windowers"),
    "create_fixed_length_windows": ("preprocessing", "..preprocessing.windowers"),
}


def __getattr__(name):
    """Resolve names that were moved out of ``braindecode.datautil``.

    Module-level ``__getattr__`` hook (PEP 562): Python only calls it for
    attributes that are not found in this module's namespace. For a name
    listed in ``_MOVED_ATTRIBUTES`` it warns about the new location and
    returns the object from the module that now defines it.

    Parameters
    ----------
    name : str
        Attribute name being looked up.

    Returns
    -------
    object
        The object named ``name`` from its new module.

    Raises
    ------
    AttributeError
        If ``name`` is not one of the moved names.
    """
    # ideas from https://stackoverflow.com/a/57110249/1469195
    import importlib
    from warnings import warn

    try:
        new_package, module_path = _MOVED_ATTRIBUTES[name]
    except KeyError:
        raise AttributeError("No possible import named " + name) from None

    warn(
        f"{name} has been moved to {new_package}, please use "
        f"from braindecode.{new_package} import {name}",
        UserWarning,
        # Point the warning at the user's import/attribute access, not here.
        stacklevel=2,
    )
    module = importlib.import_module(module_path, __package__)
    return getattr(module, name)
|
|
@@ -0,0 +1,342 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Convenience functions for storing and loading of windows datasets.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
# Authors: Lukas Gemein <l.gemein@gmail.com>
|
|
6
|
+
#
|
|
7
|
+
# License: BSD (3-clause)
|
|
8
|
+
|
|
9
|
+
import json
|
|
10
|
+
import os
|
|
11
|
+
import pickle
|
|
12
|
+
import warnings
|
|
13
|
+
from glob import glob
|
|
14
|
+
from pathlib import Path
|
|
15
|
+
|
|
16
|
+
import mne
|
|
17
|
+
import pandas as pd
|
|
18
|
+
from joblib import Parallel, delayed
|
|
19
|
+
|
|
20
|
+
from ..datasets.base import (
|
|
21
|
+
BaseConcatDataset,
|
|
22
|
+
BaseDataset,
|
|
23
|
+
EEGWindowsDataset,
|
|
24
|
+
WindowsDataset,
|
|
25
|
+
)
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def save_concat_dataset(path, concat_dataset, overwrite=False):
    """Save a concat dataset to ``path`` (deprecated).

    Thin deprecated wrapper: emits a ``UserWarning`` and forwards to
    ``concat_dataset.save(path=path, overwrite=overwrite)``.

    Parameters
    ----------
    path : str | pathlib.Path
        Target directory, passed unchanged to ``concat_dataset.save``.
    concat_dataset : BaseConcatDataset
        Dataset to store; only its ``save`` method is used here.
    overwrite : bool
        Forwarded to ``concat_dataset.save``.
    """
    warnings.warn(
        '"save_concat_dataset()" is deprecated and will be removed in'
        " the future. Use dataset.save() instead.",
        UserWarning,
        # Attribute the warning to the caller rather than this wrapper.
        stacklevel=2,
    )
    concat_dataset.save(path=path, overwrite=overwrite)
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def _outdated_load_concat_dataset(path, preload, ids_to_load=None, target_name=None):
    """Load a stored BaseConcatDataset of BaseDatasets or WindowsDatasets from
    files.

    Legacy loader for datasets saved in the old layout. Two layouts are
    accepted:

    * single dataset: ``path`` itself contains ``0-raw.fif`` or ``0-epo.fif``;
    * multiple datasets: ``path`` contains integer-named subdirectories
      (``0``, ``1``, ...) each laid out like the single case. They are
      concatenated in numeric order.

    Parameters
    ----------
    path: pathlib.Path
        Path to the directory of the .fif / -epo.fif and .json files.
    preload: bool
        Whether to preload the data.
    ids_to_load: None | list(int)
        Ids of specific files to load. In the multiple-datasets layout they
        instead select subdirectories (by position in numeric order), and all
        files inside each selected subdirectory are loaded.
    target_name: None or str
        Load specific description column as target. If not given, take saved
        target name. Only supported for raw (``-raw.fif``) data.

    Returns
    -------
    concat_dataset: BaseConcatDataset of BaseDatasets or WindowsDatasets

    Raises
    ------
    AssertionError
        If neither layout is found, or if ``target_name`` is given for
        epochs data.
    """
    # assume we have a single concat dataset to load
    is_raw = (path / "0-raw.fif").is_file()
    is_epochs = (path / "0-epo.fif").is_file()
    paths = [path]
    # assume we have multiple concat datasets to load
    if not (is_raw or is_epochs):
        is_raw = (path / "0" / "0-raw.fif").is_file()
        is_epochs = (path / "0" / "0-epo.fif").is_file()
        paths = sorted(path.glob("*/"), key=lambda p: int(p.name))
        if ids_to_load is not None:
            paths = [paths[i] for i in ids_to_load]
        # ids_to_load selected subdirectories; load every file inside them.
        ids_to_load = None
    # if we have neither a single nor multiple datasets, something went wrong
    assert is_raw or is_epochs, (
        f"Expect either raw or epo to exist in {path} or in {path / '0'}"
    )
    # Checked only once the layout is known, so raws stored in the
    # multiple-datasets layout are not wrongly rejected.
    assert not (not is_raw and target_name is not None), (
        "Setting a new target is only supported for raws."
    )

    datasets = []
    # NOTE: the loop deliberately rebinds ``path``; the kwargs files read
    # after the loop therefore come from the last loaded directory.
    for path in paths:
        if is_raw and target_name is None:
            target_file_name = path / "target_name.json"
            with open(target_file_name, "r") as f:
                target_name = json.load(f)["target_name"]

        all_signals, description = _load_signals_and_description(
            path=path, preload=preload, is_raw=is_raw, ids_to_load=ids_to_load
        )
        for i_signal, signal in enumerate(all_signals):
            if is_raw:
                datasets.append(
                    BaseDataset(
                        signal, description.iloc[i_signal], target_name=target_name
                    )
                )
            else:
                datasets.append(WindowsDataset(signal, description.iloc[i_signal]))
    concat_ds = BaseConcatDataset(datasets)
    for kwarg_name in ["raw_preproc_kwargs", "window_kwargs", "window_preproc_kwargs"]:
        kwarg_path = path / ".".join([kwarg_name, "json"])
        if kwarg_path.exists():
            with open(kwarg_path, "r") as f:
                kwargs = json.load(f)
            kwargs = [tuple(kwarg) for kwarg in kwargs]
            setattr(concat_ds, kwarg_name, kwargs)
    return concat_ds
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
def _load_signals_and_description(path, preload, is_raw, ids_to_load=None):
|
|
108
|
+
all_signals = []
|
|
109
|
+
file_name = "{}-raw.fif" if is_raw else "{}-epo.fif"
|
|
110
|
+
description_df = pd.read_json(path / "description.json")
|
|
111
|
+
if ids_to_load is None:
|
|
112
|
+
file_names = path.glob(f"*{file_name.lstrip('{}')}")
|
|
113
|
+
# Extract ids, e.g.,
|
|
114
|
+
# '/home/schirrmr/data/preproced-tuh/all-sensors/11-raw.fif' ->
|
|
115
|
+
# '11-raw.fif' -> 11
|
|
116
|
+
ids_to_load = sorted(
|
|
117
|
+
[int(os.path.split(f)[-1].split("-")[0]) for f in file_names]
|
|
118
|
+
)
|
|
119
|
+
for i in ids_to_load:
|
|
120
|
+
fif_file = path / file_name.format(i)
|
|
121
|
+
all_signals.append(_load_signals(fif_file, preload, is_raw))
|
|
122
|
+
description_df = description_df.iloc[ids_to_load]
|
|
123
|
+
return all_signals, description_df
|
|
124
|
+
|
|
125
|
+
|
|
126
|
+
def _load_signals(fif_file, preload, is_raw):
|
|
127
|
+
# Reading the raw file from pickle if it has been save before.
|
|
128
|
+
# The pickle file only contain the raw object without the data.
|
|
129
|
+
pkl_file = fif_file.with_suffix(".pkl")
|
|
130
|
+
if pkl_file.exists():
|
|
131
|
+
with open(pkl_file, "rb") as f:
|
|
132
|
+
signals = pickle.load(f)
|
|
133
|
+
|
|
134
|
+
# If the file has been moved together with the pickle file, make sure
|
|
135
|
+
# the path links to correct fif file.
|
|
136
|
+
signals._fname = str(fif_file)
|
|
137
|
+
if preload:
|
|
138
|
+
signals.load_data()
|
|
139
|
+
return signals
|
|
140
|
+
|
|
141
|
+
# If pickle didn't exist read via mne (likely slower) and save pkl after
|
|
142
|
+
if is_raw:
|
|
143
|
+
signals = mne.io.read_raw_fif(fif_file, preload=preload)
|
|
144
|
+
elif fif_file.name.endswith("-epo.fif"):
|
|
145
|
+
signals = mne.read_epochs(fif_file, preload=preload)
|
|
146
|
+
else:
|
|
147
|
+
raise ValueError("fif_file must end with raw.fif or epo.fif.")
|
|
148
|
+
|
|
149
|
+
# Only do this for raw objects. Epoch objects are not picklable as they
|
|
150
|
+
# hold references to open files in `signals._raw[0].fid`.
|
|
151
|
+
if is_raw:
|
|
152
|
+
# Saving the raw file without data into a pickle file, so it can be
|
|
153
|
+
# retrieved faster on the next use of this dataset.
|
|
154
|
+
with open(pkl_file, "wb") as f:
|
|
155
|
+
if preload:
|
|
156
|
+
data = signals._data
|
|
157
|
+
signals._data, signals.preload = None, False
|
|
158
|
+
pickle.dump(signals, f)
|
|
159
|
+
if preload:
|
|
160
|
+
signals._data, signals.preload = data, True
|
|
161
|
+
|
|
162
|
+
return signals
|
|
163
|
+
|
|
164
|
+
|
|
165
|
+
def load_concat_dataset(path, preload, ids_to_load=None, target_name=None, n_jobs=1):
    """Load a stored BaseConcatDataset of BaseDatasets or WindowsDatasets from
    files.

    Parameters
    ----------
    path: str | pathlib.Path
        Path to the directory of the .fif / -epo.fif and .json files.
    preload: bool
        Whether to preload the data.
    ids_to_load: list of int | None
        Ids of specific files to load. If None, every integer-named entry in
        ``path`` is loaded; other entries (e.g. ``.DS_Store``) are ignored.
    target_name: str | list | None
        Load specific description column as target. If not given, take saved
        target name. Only forwarded to the legacy loader.
    n_jobs: int
        Number of jobs to be used to read files in parallel. Forced to 1 for
        data stored as ``-epo.fif``.

    Returns
    -------
    concat_dataset: BaseConcatDataset of BaseDatasets or WindowsDatasets
    """
    # Make sure we always work with a pathlib.Path
    path = Path(path)

    # if we encounter a dataset that was saved in 'the old way', call the
    # corresponding 'old' loading function
    if _is_outdated_saved(path):
        warnings.warn(
            "The way your dataset was saved is deprecated by now. "
            "Please save it again using dataset.save().",
            UserWarning,
            stacklevel=2,
        )
        return _outdated_load_concat_dataset(
            path=path, preload=preload, ids_to_load=ids_to_load, target_name=target_name
        )

    # else we have a dataset saved in the new way with subdirectories in path
    # for every dataset with description.json and -epo.fif or -raw.fif,
    # target_name.json, raw_preproc_kwargs.json, window_kwargs.json,
    # window_preproc_kwargs.json
    if ids_to_load is None:
        # Only integer-named entries are dataset subdirectories (same rule as
        # _check_save_dir_empty); anything else would crash int() below.
        ids_to_load = [p.name for p in path.iterdir() if p.name.isdigit()]
    ids_to_load = sorted(ids_to_load, key=lambda i: int(i))
    ids_to_load = [str(i) for i in ids_to_load]
    first_raw_fif_path = path / ids_to_load[0] / f"{ids_to_load[0]}-raw.fif"
    is_raw = first_raw_fif_path.exists()
    metadata_path = path / ids_to_load[0] / "metadata_df.pkl"
    has_stored_windows = metadata_path.exists()

    # Parallelization of mne.read_epochs with preload=False fails with
    # 'TypeError: cannot pickle '_io.BufferedReader' object'.
    # So ignore n_jobs in that case and load with a single job.
    if not is_raw and n_jobs != 1:
        warnings.warn(
            "Parallelized reading with `preload=False` is not supported for "
            "windowed data. Will use `n_jobs=1`.",
            UserWarning,
            stacklevel=2,
        )
        n_jobs = 1
    datasets = Parallel(n_jobs)(
        delayed(_load_parallel)(path, i, preload, is_raw, has_stored_windows)
        for i in ids_to_load
    )
    return BaseConcatDataset(datasets)
|
|
230
|
+
|
|
231
|
+
|
|
232
|
+
def _load_parallel(path, i, preload, is_raw, has_stored_windows):
    """Load the single dataset stored in subdirectory ``path / i``.

    Parameters
    ----------
    path : pathlib.Path
        Parent directory of the dataset subdirectories.
    i : str
        Name of the subdirectory; also the prefix of its fif file.
    preload : bool
        Forwarded to ``_load_signals``.
    is_raw : bool
        Whether the subdirectory holds ``<i>-raw.fif`` (True) or
        ``<i>-epo.fif`` (False).
    has_stored_windows : bool
        Whether windows metadata (``metadata_df.pkl``) was stored. It selects
        ``EEGWindowsDataset`` over ``BaseDataset`` for raw data.

    Returns
    -------
    BaseDataset | EEGWindowsDataset | WindowsDataset
        The dataset, with ``raw_preproc_kwargs`` and ``window_preproc_kwargs``
        (and ``window_kwargs`` for windowed data) attached.

    Raises
    ------
    FileExistsError
        If both ``<i>-raw.fif`` and ``<i>-epo.fif`` exist in the subdirectory.
    """
    sub_dir = path / i
    file_name_patterns = ["{}-raw.fif", "{}-epo.fif"]
    if all((sub_dir / p.format(i)).exists() for p in file_name_patterns):
        raise FileExistsError("Found -raw.fif and -epo.fif in directory.")

    fif_name_pattern = file_name_patterns[0] if is_raw else file_name_patterns[1]
    fif_file_path = sub_dir / fif_name_pattern.format(i)

    signals = _load_signals(fif_file_path, preload, is_raw)

    description_file_path = sub_dir / "description.json"
    description = pd.read_json(description_file_path, typ="series")

    target_file_path = sub_dir / "target_name.json"
    target_name = None
    if target_file_path.exists():
        with open(target_file_path, "r") as f:
            target_name = json.load(f)["target_name"]

    if is_raw and (not has_stored_windows):
        dataset = BaseDataset(signals, description, target_name)
    else:
        window_kwargs = _load_kwargs_json("window_kwargs", sub_dir)
        # A missing window_kwargs.json yields None; fall back to defaults
        # instead of failing on iteration.
        windows_ds_kwargs = [
            kwargs[1]
            for kwargs in (window_kwargs or [])
            if kwargs[0] == "WindowsDataset"
        ]
        windows_ds_kwargs = windows_ds_kwargs[0] if len(windows_ds_kwargs) == 1 else {}
        if is_raw:
            # NOTE(review): read_pickle unpickles arbitrary objects; only load
            # datasets from trusted locations.
            metadata = pd.read_pickle(sub_dir / "metadata_df.pkl")
            dataset = EEGWindowsDataset(
                signals,
                metadata=metadata,
                description=description,
                targets_from=windows_ds_kwargs.get("targets_from", "metadata"),
                last_target_only=windows_ds_kwargs.get("last_target_only", True),
            )
        else:
            # MNE epochs dataset
            dataset = WindowsDataset(
                signals,
                description,
                targets_from=windows_ds_kwargs.get("targets_from", "metadata"),
                last_target_only=windows_ds_kwargs.get("last_target_only", True),
            )
        setattr(dataset, "window_kwargs", window_kwargs)
    for kwargs_name in ["raw_preproc_kwargs", "window_preproc_kwargs"]:
        kwargs = _load_kwargs_json(kwargs_name, sub_dir)
        setattr(dataset, kwargs_name, kwargs)
    return dataset
|
|
282
|
+
|
|
283
|
+
|
|
284
|
+
def _load_kwargs_json(kwargs_name, sub_dir):
|
|
285
|
+
kwargs_file_name = ".".join([kwargs_name, "json"])
|
|
286
|
+
kwargs_file_path = os.path.join(sub_dir, kwargs_file_name)
|
|
287
|
+
if os.path.exists(kwargs_file_path):
|
|
288
|
+
kwargs = json.load(open(kwargs_file_path, "r"))
|
|
289
|
+
kwargs = [tuple(kwarg) for kwarg in kwargs]
|
|
290
|
+
return kwargs
|
|
291
|
+
|
|
292
|
+
|
|
293
|
+
def _is_outdated_saved(path):
|
|
294
|
+
"""Data was saved in the old way if there are 'description.json', '-raw.fif'
|
|
295
|
+
or '-epo.fif' in path (no subdirectories) OR if there are more 'fif' files
|
|
296
|
+
than 'description.json' files."""
|
|
297
|
+
description_files = glob(os.path.join(path, "**/description.json"))
|
|
298
|
+
fif_files = glob(os.path.join(path, "**/*-raw.fif")) + glob(
|
|
299
|
+
os.path.join(path, "**/*-epo.fif")
|
|
300
|
+
)
|
|
301
|
+
multiple = len(description_files) != len(fif_files)
|
|
302
|
+
kwargs_in_path = any(
|
|
303
|
+
[
|
|
304
|
+
os.path.exists(os.path.join(path, kwarg_name))
|
|
305
|
+
for kwarg_name in [
|
|
306
|
+
"raw_preproc_kwargs",
|
|
307
|
+
"window_kwargs",
|
|
308
|
+
"window_preproc_kwargs",
|
|
309
|
+
]
|
|
310
|
+
]
|
|
311
|
+
)
|
|
312
|
+
return (
|
|
313
|
+
os.path.exists(os.path.join(path, "description.json"))
|
|
314
|
+
or os.path.exists(os.path.join(path, "0-raw.fif"))
|
|
315
|
+
or os.path.exists(os.path.join(path, "0-epo.fif"))
|
|
316
|
+
or multiple
|
|
317
|
+
or kwargs_in_path
|
|
318
|
+
)
|
|
319
|
+
|
|
320
|
+
|
|
321
|
+
def _check_save_dir_empty(save_dir):
|
|
322
|
+
"""Make sure a BaseConcatDataset can be saved under a given directory.
|
|
323
|
+
|
|
324
|
+
Parameters
|
|
325
|
+
----------
|
|
326
|
+
save_dir : str
|
|
327
|
+
Directory under which a `BaseConcatDataset` will be saved.
|
|
328
|
+
|
|
329
|
+
Raises
|
|
330
|
+
-------
|
|
331
|
+
FileExistsError
|
|
332
|
+
If ``save_dir`` is not a valid directory for saving.
|
|
333
|
+
"""
|
|
334
|
+
sub_dirs = [
|
|
335
|
+
os.path.basename(s).isdigit() for s in glob(os.path.join(save_dir, "*"))
|
|
336
|
+
]
|
|
337
|
+
if any(sub_dirs):
|
|
338
|
+
raise FileExistsError(
|
|
339
|
+
f"Directory {save_dir} already contains subdirectories. Please "
|
|
340
|
+
"select a different directory, set overwrite=True, or resolve "
|
|
341
|
+
"manually."
|
|
342
|
+
)
|
|
@@ -0,0 +1,41 @@
|
|
|
1
|
+
# Authors: Robin Schirrmeister <robintibor@gmail.com>
|
|
2
|
+
#
|
|
3
|
+
# License: BSD (3-clause)
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
def ms_to_samples(ms, fs):
    """
    Convert a duration in milliseconds to a number of samples.

    Parameters
    ----------
    ms: number
        Duration in milliseconds.
    fs: number
        Sampling rate in Hz.

    Returns
    -------
    n_samples: float
        Number of samples, ``ms * fs / 1000``. The result is not rounded;
        apply ``int``/``round`` yourself if an integer count is required.

    """
    return ms * fs / 1000.0
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def samples_to_ms(n_samples, fs):
    """
    Convert a number of samples to a duration in milliseconds.

    Parameters
    ----------
    n_samples: number
        Number of samples.
    fs: number
        Sampling rate in Hz.

    Returns
    -------
    milliseconds: float
        Duration in milliseconds, ``n_samples * 1000 / fs`` (not rounded).
    """
    return n_samples * 1000.0 / fs
|
braindecode/eegneuralnet.py
CHANGED
|
@@ -5,32 +5,36 @@
|
|
|
5
5
|
|
|
6
6
|
|
|
7
7
|
import abc
|
|
8
|
-
import logging
|
|
9
8
|
import inspect
|
|
9
|
+
import logging
|
|
10
10
|
|
|
11
11
|
import mne
|
|
12
12
|
import numpy as np
|
|
13
13
|
import torch
|
|
14
|
-
from skorch import NeuralNet
|
|
15
14
|
from sklearn.metrics import get_scorer
|
|
15
|
+
from skorch import NeuralNet
|
|
16
16
|
from skorch.callbacks import BatchScoring, EpochScoring, EpochTimer, PrintLog
|
|
17
|
-
from skorch.
|
|
17
|
+
from skorch.helper import SliceDataset
|
|
18
|
+
from skorch.utils import is_dataset, noop, to_numpy, train_loss_score, valid_loss_score
|
|
18
19
|
|
|
19
|
-
from .training.scoring import (CroppedTimeSeriesEpochScoring,
|
|
20
|
-
CroppedTrialEpochScoring, PostEpochTrainScoring)
|
|
21
|
-
from .models.util import models_dict
|
|
22
20
|
from .datasets.base import BaseConcatDataset, WindowsDataset
|
|
21
|
+
from .models.util import models_dict
|
|
22
|
+
from .training.scoring import (
|
|
23
|
+
CroppedTimeSeriesEpochScoring,
|
|
24
|
+
CroppedTrialEpochScoring,
|
|
25
|
+
PostEpochTrainScoring,
|
|
26
|
+
)
|
|
23
27
|
|
|
24
28
|
log = logging.getLogger(__name__)
|
|
25
29
|
|
|
26
30
|
|
|
27
|
-
def _get_model(model):
|
|
28
|
-
|
|
31
|
+
def _get_model(model: str):
|
|
32
|
+
"""Returns the corresponding class in case the model passed is a string."""
|
|
29
33
|
if isinstance(model, str):
|
|
30
34
|
if model in models_dict:
|
|
31
35
|
model = models_dict[model]
|
|
32
36
|
else:
|
|
33
|
-
raise ValueError(f
|
|
37
|
+
raise ValueError(f"Unknown model name {model!r}.")
|
|
34
38
|
return model
|
|
35
39
|
|
|
36
40
|
|
|
@@ -50,7 +54,7 @@ class _EEGNeuralNet(NeuralNet, abc.ABC):
|
|
|
50
54
|
will be left as is.
|
|
51
55
|
|
|
52
56
|
"""
|
|
53
|
-
kwargs = self.get_params_for(
|
|
57
|
+
kwargs = self.get_params_for("module")
|
|
54
58
|
module = _get_model(self.module)
|
|
55
59
|
module = self.initialized_instance(module, kwargs)
|
|
56
60
|
# pylint: disable=attribute-defined-outside-init
|
|
@@ -61,7 +65,7 @@ class _EEGNeuralNet(NeuralNet, abc.ABC):
|
|
|
61
65
|
# Here we parse the callbacks supplied as strings,
|
|
62
66
|
# e.g. 'accuracy', to the callbacks skorch expects
|
|
63
67
|
for name, cb, named_by_user in super()._yield_callbacks():
|
|
64
|
-
if name ==
|
|
68
|
+
if name == "str":
|
|
65
69
|
train_cb, valid_cb = self._parse_str_callback(cb)
|
|
66
70
|
yield train_cb
|
|
67
71
|
if self.train_split is not None:
|
|
@@ -72,15 +76,13 @@ class _EEGNeuralNet(NeuralNet, abc.ABC):
|
|
|
72
76
|
def _parse_str_callback(self, cb_supplied_name):
|
|
73
77
|
scoring = get_scorer(cb_supplied_name)
|
|
74
78
|
scoring_name = scoring._score_func.__name__
|
|
75
|
-
assert scoring_name.endswith(
|
|
76
|
-
|
|
77
|
-
if (scoring_name.endswith('_score') or
|
|
78
|
-
cb_supplied_name.startswith('neg_')):
|
|
79
|
+
assert scoring_name.endswith(("_score", "_error", "_deviance", "_loss"))
|
|
80
|
+
if scoring_name.endswith("_score") or cb_supplied_name.startswith("neg_"):
|
|
79
81
|
lower_is_better = False
|
|
80
82
|
else:
|
|
81
83
|
lower_is_better = True
|
|
82
|
-
train_name = f
|
|
83
|
-
valid_name = f
|
|
84
|
+
train_name = f"train_{cb_supplied_name}"
|
|
85
|
+
valid_name = f"valid_{cb_supplied_name}"
|
|
84
86
|
if self.cropped:
|
|
85
87
|
train_scoring = CroppedTrialEpochScoring(
|
|
86
88
|
cb_supplied_name, lower_is_better, on_train=True, name=train_name
|
|
@@ -98,7 +100,7 @@ class _EEGNeuralNet(NeuralNet, abc.ABC):
|
|
|
98
100
|
named_by_user = True
|
|
99
101
|
train_valid_callbacks = [
|
|
100
102
|
(train_name, train_scoring, named_by_user),
|
|
101
|
-
(valid_name, valid_scoring, named_by_user)
|
|
103
|
+
(valid_name, valid_scoring, named_by_user),
|
|
102
104
|
]
|
|
103
105
|
return train_valid_callbacks
|
|
104
106
|
|
|
@@ -108,8 +110,13 @@ class _EEGNeuralNet(NeuralNet, abc.ABC):
|
|
|
108
110
|
if not training:
|
|
109
111
|
epoch_cbs = []
|
|
110
112
|
for name, cb in self.callbacks_:
|
|
111
|
-
if
|
|
112
|
-
|
|
113
|
+
if (
|
|
114
|
+
isinstance(
|
|
115
|
+
cb, (CroppedTrialEpochScoring, CroppedTimeSeriesEpochScoring)
|
|
116
|
+
)
|
|
117
|
+
and (hasattr(cb, "window_inds_"))
|
|
118
|
+
and (not cb.on_train)
|
|
119
|
+
):
|
|
113
120
|
epoch_cbs.append(cb)
|
|
114
121
|
# for trialwise decoding stuffs it might also be we don't have
|
|
115
122
|
# cropped loader, so no indices there
|
|
@@ -136,8 +143,11 @@ class _EEGNeuralNet(NeuralNet, abc.ABC):
|
|
|
136
143
|
i_window_stops = np.concatenate(i_window_stops)
|
|
137
144
|
window_ys = np.concatenate(window_ys)
|
|
138
145
|
return dict(
|
|
139
|
-
preds=preds,
|
|
140
|
-
|
|
146
|
+
preds=preds,
|
|
147
|
+
i_window_in_trials=i_window_in_trials,
|
|
148
|
+
i_window_stops=i_window_stops,
|
|
149
|
+
window_ys=window_ys,
|
|
150
|
+
)
|
|
141
151
|
|
|
142
152
|
# Changes the default target extractor to noop
|
|
143
153
|
@property
|
|
@@ -156,7 +166,9 @@ class _EEGNeuralNet(NeuralNet, abc.ABC):
|
|
|
156
166
|
(
|
|
157
167
|
"valid_loss",
|
|
158
168
|
BatchScoring(
|
|
159
|
-
valid_loss_score,
|
|
169
|
+
valid_loss_score,
|
|
170
|
+
name="valid_loss",
|
|
171
|
+
target_extractor=noop,
|
|
160
172
|
),
|
|
161
173
|
),
|
|
162
174
|
("print_log", PrintLog()),
|
|
@@ -179,17 +191,27 @@ class _EEGNeuralNet(NeuralNet, abc.ABC):
|
|
|
179
191
|
return
|
|
180
192
|
# get kwargs from signal:
|
|
181
193
|
signal_kwargs = dict()
|
|
182
|
-
|
|
194
|
+
# Using shape to work both with torch.tensor and numpy.array:
|
|
195
|
+
if (
|
|
196
|
+
isinstance(X, mne.BaseEpochs)
|
|
197
|
+
or (hasattr(X, "shape") and len(X.shape) >= 2)
|
|
198
|
+
or isinstance(X, SliceDataset)
|
|
199
|
+
):
|
|
183
200
|
if y is None:
|
|
184
|
-
raise ValueError("y must be specified if X is
|
|
185
|
-
signal_kwargs[
|
|
201
|
+
raise ValueError("y must be specified if X is array-like.")
|
|
202
|
+
signal_kwargs["n_outputs"] = self._get_n_outputs(y, classes)
|
|
186
203
|
if isinstance(X, mne.BaseEpochs):
|
|
187
204
|
self.log.info("Using mne.Epochs to find signal-related parameters.")
|
|
188
205
|
signal_kwargs["n_times"] = len(X.times)
|
|
189
|
-
signal_kwargs["sfreq"] = X.info[
|
|
190
|
-
signal_kwargs["chs_info"] = X.info[
|
|
206
|
+
signal_kwargs["sfreq"] = X.info["sfreq"]
|
|
207
|
+
signal_kwargs["chs_info"] = X.info["chs"]
|
|
208
|
+
elif isinstance(X, SliceDataset):
|
|
209
|
+
self.log.info("Using SliceDataset to find signal-related parameters.")
|
|
210
|
+
Xshape = X[0].shape
|
|
211
|
+
signal_kwargs["n_times"] = Xshape[-1]
|
|
212
|
+
signal_kwargs["n_chans"] = Xshape[-2]
|
|
191
213
|
else:
|
|
192
|
-
self.log.info("Using
|
|
214
|
+
self.log.info("Using array-like to find signal-related parameters.")
|
|
193
215
|
signal_kwargs["n_times"] = X.shape[-1]
|
|
194
216
|
signal_kwargs["n_chans"] = X.shape[-2]
|
|
195
217
|
elif is_dataset(X):
|
|
@@ -198,21 +220,17 @@ class _EEGNeuralNet(NeuralNet, abc.ABC):
|
|
|
198
220
|
Xshape = X0.shape
|
|
199
221
|
signal_kwargs["n_times"] = Xshape[-1]
|
|
200
222
|
signal_kwargs["n_chans"] = Xshape[-2]
|
|
201
|
-
if (
|
|
202
|
-
|
|
203
|
-
all(ds.targets_from == 'metadata' for ds in X.datasets)
|
|
223
|
+
if isinstance(X, BaseConcatDataset) and all(
|
|
224
|
+
ds.targets_from == "metadata" for ds in X.datasets
|
|
204
225
|
):
|
|
205
226
|
y_target = X.get_metadata().target
|
|
206
|
-
signal_kwargs[
|
|
207
|
-
elif (
|
|
208
|
-
isinstance(X, WindowsDataset) and
|
|
209
|
-
X.targets_from == "metadata"
|
|
210
|
-
):
|
|
227
|
+
signal_kwargs["n_outputs"] = self._get_n_outputs(y_target, classes)
|
|
228
|
+
elif isinstance(X, WindowsDataset) and X.targets_from == "metadata":
|
|
211
229
|
y_target = X.windows.metadata.target
|
|
212
|
-
signal_kwargs[
|
|
230
|
+
signal_kwargs["n_outputs"] = self._get_n_outputs(y_target, classes)
|
|
213
231
|
else:
|
|
214
232
|
self.log.warning(
|
|
215
|
-
"Can only infer signal shape of
|
|
233
|
+
"Can only infer signal shape of array-like and Datasets, "
|
|
216
234
|
f"got {type(X)!r}."
|
|
217
235
|
)
|
|
218
236
|
return
|
|
@@ -227,15 +245,13 @@ class _EEGNeuralNet(NeuralNet, abc.ABC):
|
|
|
227
245
|
if k in all_module_kwargs:
|
|
228
246
|
module_kwargs[k] = v
|
|
229
247
|
else:
|
|
230
|
-
self.log.warning(
|
|
231
|
-
f"Module {self.module!r} "
|
|
232
|
-
f"is missing parameter {k!r}."
|
|
233
|
-
)
|
|
248
|
+
self.log.warning(f"Module {self.module!r} is missing parameter {k!r}.")
|
|
234
249
|
|
|
235
250
|
# save kwargs to self:
|
|
236
251
|
self.log.info(
|
|
237
252
|
f"Passing additional parameters {module_kwargs!r} "
|
|
238
|
-
f"to module {self.module!r}."
|
|
253
|
+
f"to module {self.module!r}."
|
|
254
|
+
)
|
|
239
255
|
module_kwargs = {f"module__{k}": v for k, v in module_kwargs.items()}
|
|
240
256
|
self.set_params(**module_kwargs)
|
|
241
257
|
|
|
@@ -275,7 +291,7 @@ class _EEGNeuralNet(NeuralNet, abc.ABC):
|
|
|
275
291
|
|
|
276
292
|
"""
|
|
277
293
|
if isinstance(X, mne.BaseEpochs):
|
|
278
|
-
X = X.get_data(units=
|
|
294
|
+
X = X.get_data(units="uV")
|
|
279
295
|
return super().get_dataset(X, y)
|
|
280
296
|
|
|
281
297
|
def partial_fit(self, X, y=None, classes=None, **fit_params):
|
|
@@ -291,7 +307,7 @@ class _EEGNeuralNet(NeuralNet, abc.ABC):
|
|
|
291
307
|
|
|
292
308
|
* mne.Epochs: ``n_times``, ``n_chans``, ``n_outputs``, ``chs_info``,
|
|
293
309
|
``sfreq``, ``input_window_seconds``
|
|
294
|
-
*
|
|
310
|
+
* array-like: ``n_times``, ``n_chans``, ``n_outputs``
|
|
295
311
|
* WindowsDataset with ``targets_from='metadata'``
|
|
296
312
|
(or BaseConcatDataset of such datasets): ``n_times``, ``n_chans``, ``n_outputs``
|
|
297
313
|
* other Dataset: ``n_times``, ``n_chans``
|
|
@@ -345,7 +361,7 @@ class _EEGNeuralNet(NeuralNet, abc.ABC):
|
|
|
345
361
|
|
|
346
362
|
* mne.Epochs: ``n_times``, ``n_chans``, ``n_outputs``, ``chs_info``,
|
|
347
363
|
``sfreq``, ``input_window_seconds``
|
|
348
|
-
*
|
|
364
|
+
* array-like: ``n_times``, ``n_chans``, ``n_outputs``
|
|
349
365
|
* WindowsDataset with ``targets_from='metadata'``
|
|
350
366
|
(or BaseConcatDataset of such datasets): ``n_times``, ``n_chans``, ``n_outputs``
|
|
351
367
|
* other Dataset: ``n_times``, ``n_chans``
|