braindecode 1.3.0.dev177069446__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- braindecode/__init__.py +9 -0
- braindecode/augmentation/__init__.py +52 -0
- braindecode/augmentation/base.py +225 -0
- braindecode/augmentation/functional.py +1300 -0
- braindecode/augmentation/transforms.py +1356 -0
- braindecode/classifier.py +258 -0
- braindecode/datasets/__init__.py +44 -0
- braindecode/datasets/base.py +823 -0
- braindecode/datasets/bbci.py +693 -0
- braindecode/datasets/bcicomp.py +193 -0
- braindecode/datasets/bids/__init__.py +54 -0
- braindecode/datasets/bids/datasets.py +239 -0
- braindecode/datasets/bids/format.py +717 -0
- braindecode/datasets/bids/hub.py +987 -0
- braindecode/datasets/bids/hub_format.py +717 -0
- braindecode/datasets/bids/hub_io.py +197 -0
- braindecode/datasets/bids/hub_validation.py +114 -0
- braindecode/datasets/bids/iterable.py +220 -0
- braindecode/datasets/chb_mit.py +163 -0
- braindecode/datasets/mne.py +170 -0
- braindecode/datasets/moabb.py +219 -0
- braindecode/datasets/nmt.py +313 -0
- braindecode/datasets/registry.py +120 -0
- braindecode/datasets/siena.py +162 -0
- braindecode/datasets/sleep_physio_challe_18.py +411 -0
- braindecode/datasets/sleep_physionet.py +125 -0
- braindecode/datasets/tuh.py +591 -0
- braindecode/datasets/utils.py +67 -0
- braindecode/datasets/xy.py +96 -0
- braindecode/datautil/__init__.py +62 -0
- braindecode/datautil/channel_utils.py +114 -0
- braindecode/datautil/hub_formats.py +180 -0
- braindecode/datautil/serialization.py +359 -0
- braindecode/datautil/util.py +154 -0
- braindecode/eegneuralnet.py +372 -0
- braindecode/functional/__init__.py +22 -0
- braindecode/functional/functions.py +251 -0
- braindecode/functional/initialization.py +47 -0
- braindecode/models/__init__.py +117 -0
- braindecode/models/atcnet.py +830 -0
- braindecode/models/attentionbasenet.py +727 -0
- braindecode/models/attn_sleep.py +549 -0
- braindecode/models/base.py +574 -0
- braindecode/models/bendr.py +493 -0
- braindecode/models/biot.py +537 -0
- braindecode/models/brainmodule.py +845 -0
- braindecode/models/config.py +233 -0
- braindecode/models/contrawr.py +319 -0
- braindecode/models/ctnet.py +541 -0
- braindecode/models/deep4.py +376 -0
- braindecode/models/deepsleepnet.py +417 -0
- braindecode/models/eegconformer.py +475 -0
- braindecode/models/eeginception_erp.py +379 -0
- braindecode/models/eeginception_mi.py +379 -0
- braindecode/models/eegitnet.py +302 -0
- braindecode/models/eegminer.py +256 -0
- braindecode/models/eegnet.py +359 -0
- braindecode/models/eegnex.py +354 -0
- braindecode/models/eegsimpleconv.py +201 -0
- braindecode/models/eegsym.py +917 -0
- braindecode/models/eegtcnet.py +337 -0
- braindecode/models/fbcnet.py +225 -0
- braindecode/models/fblightconvnet.py +315 -0
- braindecode/models/fbmsnet.py +338 -0
- braindecode/models/hybrid.py +126 -0
- braindecode/models/ifnet.py +443 -0
- braindecode/models/labram.py +1316 -0
- braindecode/models/luna.py +891 -0
- braindecode/models/medformer.py +760 -0
- braindecode/models/msvtnet.py +377 -0
- braindecode/models/patchedtransformer.py +640 -0
- braindecode/models/reve.py +843 -0
- braindecode/models/sccnet.py +280 -0
- braindecode/models/shallow_fbcsp.py +212 -0
- braindecode/models/signal_jepa.py +1122 -0
- braindecode/models/sinc_shallow.py +339 -0
- braindecode/models/sleep_stager_blanco_2020.py +169 -0
- braindecode/models/sleep_stager_chambon_2018.py +159 -0
- braindecode/models/sparcnet.py +426 -0
- braindecode/models/sstdpn.py +869 -0
- braindecode/models/summary.csv +47 -0
- braindecode/models/syncnet.py +234 -0
- braindecode/models/tcn.py +275 -0
- braindecode/models/tidnet.py +397 -0
- braindecode/models/tsinception.py +295 -0
- braindecode/models/usleep.py +439 -0
- braindecode/models/util.py +369 -0
- braindecode/modules/__init__.py +92 -0
- braindecode/modules/activation.py +86 -0
- braindecode/modules/attention.py +883 -0
- braindecode/modules/blocks.py +160 -0
- braindecode/modules/convolution.py +330 -0
- braindecode/modules/filter.py +654 -0
- braindecode/modules/layers.py +216 -0
- braindecode/modules/linear.py +70 -0
- braindecode/modules/parametrization.py +38 -0
- braindecode/modules/stats.py +87 -0
- braindecode/modules/util.py +85 -0
- braindecode/modules/wrapper.py +90 -0
- braindecode/preprocessing/__init__.py +271 -0
- braindecode/preprocessing/eegprep_preprocess.py +1317 -0
- braindecode/preprocessing/mne_preprocess.py +240 -0
- braindecode/preprocessing/preprocess.py +579 -0
- braindecode/preprocessing/util.py +177 -0
- braindecode/preprocessing/windowers.py +1037 -0
- braindecode/regressor.py +234 -0
- braindecode/samplers/__init__.py +18 -0
- braindecode/samplers/base.py +399 -0
- braindecode/samplers/ssl.py +263 -0
- braindecode/training/__init__.py +23 -0
- braindecode/training/callbacks.py +23 -0
- braindecode/training/losses.py +105 -0
- braindecode/training/scoring.py +477 -0
- braindecode/util.py +419 -0
- braindecode/version.py +1 -0
- braindecode/visualization/__init__.py +8 -0
- braindecode/visualization/confusion_matrices.py +289 -0
- braindecode/visualization/gradients.py +62 -0
- braindecode-1.3.0.dev177069446.dist-info/METADATA +230 -0
- braindecode-1.3.0.dev177069446.dist-info/RECORD +124 -0
- braindecode-1.3.0.dev177069446.dist-info/WHEEL +5 -0
- braindecode-1.3.0.dev177069446.dist-info/licenses/LICENSE.txt +31 -0
- braindecode-1.3.0.dev177069446.dist-info/licenses/NOTICE.txt +20 -0
- braindecode-1.3.0.dev177069446.dist-info/top_level.txt +1 -0
braindecode/datautil/serialization.py
@@ -0,0 +1,359 @@
+"""Convenience functions for storing and loading of windows datasets."""
+
+# Authors: Lukas Gemein <l.gemein@gmail.com>
+#
+# License: BSD (3-clause)
+
+import json
+import os
+import pickle
+import warnings
+from glob import glob
+from pathlib import Path
+
+import mne
+import pandas as pd
+from joblib import Parallel, delayed
+
+from ..datasets.base import (
+    BaseConcatDataset,
+    EEGWindowsDataset,
+    RawDataset,
+    WindowsDataset,
+)
+
+
+def save_concat_dataset(path, concat_dataset, overwrite=False):
+    warnings.warn(
+        '"save_concat_dataset()" is deprecated and will be removed in'
+        " the future. Use dataset.save() instead.",
+        UserWarning,
+    )
+    concat_dataset.save(path=path, overwrite=overwrite)
+
+
+def _outdated_load_concat_dataset(path, preload, ids_to_load=None, target_name=None):
+    """Load a stored BaseConcatDataset from files.
+
+    Parameters
+    ----------
+    path : pathlib.Path
+        Path to the directory of the .fif / -epo.fif and .json files.
+    preload : bool
+        Whether to preload the data.
+    ids_to_load : list of int | None
+        Ids of specific files to load.
+    target_name : str | None
+        Load specific description column as target. If not given, take saved
+        target name.
+
+    Returns
+    -------
+    concat_dataset : BaseConcatDataset
+    """
+    # assume we have a single concat dataset to load
+    is_raw = (path / "0-raw.fif").is_file()
+    assert not (not is_raw and target_name is not None), (
+        "Setting a new target is only supported for raws."
+    )
+    is_epochs = (path / "0-epo.fif").is_file()
+    paths = [path]
+    # assume we have multiple concat datasets to load
+    if not (is_raw or is_epochs):
+        is_raw = (path / "0" / "0-raw.fif").is_file()
+        is_epochs = (path / "0" / "0-epo.fif").is_file()
+        paths = path.glob("*/")
+        paths = sorted(paths, key=lambda p: int(p.name))
+        if ids_to_load is not None:
+            paths = [paths[i] for i in ids_to_load]
+            ids_to_load = None
+    # if we have neither a single nor multiple datasets, something went wrong
+    assert is_raw or is_epochs, (
+        f"Expect either raw or epo to exist in {path} or in {path / '0'}"
+    )
+
+    datasets = []
+    for path in paths:
+        if is_raw and target_name is None:
+            target_file_name = path / "target_name.json"
+            target_name = json.load(open(target_file_name, "r"))["target_name"]
+
+        all_signals, description = _load_signals_and_description(
+            path=path, preload=preload, is_raw=is_raw, ids_to_load=ids_to_load
+        )
+        for i_signal, signal in enumerate(all_signals):
+            if is_raw:
+                datasets.append(
+                    RawDataset(
+                        signal, description.iloc[i_signal], target_name=target_name
+                    )
+                )
+            else:
+                datasets.append(WindowsDataset(signal, description.iloc[i_signal]))
+    concat_ds = BaseConcatDataset(datasets)
+    for kwarg_name in ["raw_preproc_kwargs", "window_kwargs", "window_preproc_kwargs"]:
+        kwarg_path = path / ".".join([kwarg_name, "json"])
+        if kwarg_path.exists():
+            with open(kwarg_path, "r") as f:
+                kwargs = json.load(f)
+            kwargs = [tuple(kwarg) for kwarg in kwargs]
+            setattr(concat_ds, kwarg_name, kwargs)
+    return concat_ds
+
+
+def _load_signals_and_description(path, preload, is_raw, ids_to_load=None):
+    all_signals = []
+    file_name = "{}-raw.fif" if is_raw else "{}-epo.fif"
+    description_df = pd.read_json(
+        path / "description.json", typ="series", convert_dates=False
+    )
+
+    if "timestamp" in description_df.index:
+        timestamp_numeric = pd.to_numeric(description_df["timestamp"])
+        description_df["timestamp"] = pd.to_datetime(timestamp_numeric)
+
+    if ids_to_load is None:
+        file_names = path.glob(f"*{file_name.lstrip('{}')}")
+        # Extract ids, e.g.,
+        # '/home/schirrmr/data/preproced-tuh/all-sensors/11-raw.fif' ->
+        # '11-raw.fif' -> 11
+        ids_to_load = sorted(
+            [int(os.path.split(f)[-1].split("-")[0]) for f in file_names]
+        )
+    for i in ids_to_load:
+        fif_file = path / file_name.format(i)
+        all_signals.append(_load_signals(fif_file, preload, is_raw))
+    description_df = description_df.iloc[ids_to_load]
+    return all_signals, description_df
+
+
+def _load_signals(fif_file, preload, is_raw):
+    # Read the raw object from a pickle file if it has been saved before.
+    # The pickle file only contains the raw object without the data.
+    pkl_file = fif_file.with_suffix(".pkl")
+    if pkl_file.exists():
+        with open(pkl_file, "rb") as f:
+            signals = pickle.load(f)
+
+        if all(Path(f).exists() for f in signals.filenames):
+            if preload:
+                signals.load_data()
+            return signals
+        else:  # This may happen if the files were moved together with the pickle file.
+            warnings.warn(
+                f"Pickle file {pkl_file} exists, but the referenced fif "
+                "file(s) do not exist. Will read the fif file(s) directly "
+                "and re-create the pickle file.",
+                UserWarning,
+            )
+
+    # If the pickle didn't exist, read via MNE (likely slower) and save a pickle after
+    if is_raw:
+        signals = mne.io.read_raw_fif(fif_file, preload=preload)
+    elif fif_file.name.endswith("-epo.fif"):
+        signals = mne.read_epochs(fif_file, preload=preload)
+    else:
+        raise ValueError("fif_file must end with raw.fif or epo.fif.")
+
+    # Only do this for raw objects. Epoch objects are not picklable as they
+    # hold references to open files in `signals._raw[0].fid`.
+    if is_raw:
+        # Save the raw object without data into a pickle file, so it can be
+        # retrieved faster on the next use of this dataset.
+        with open(pkl_file, "wb") as f:
+            if preload:
+                data = signals._data
+                signals._data, signals.preload = None, False
+            pickle.dump(signals, f)
+            if preload:
+                signals._data, signals.preload = data, True
+
+    return signals
+
+
+def load_concat_dataset(path, preload, ids_to_load=None, target_name=None, n_jobs=1):
+    """Load a stored BaseConcatDataset from files.
+
+    Parameters
+    ----------
+    path : str | pathlib.Path
+        Path to the directory of the .fif / -epo.fif and .json files.
+    preload : bool
+        Whether to preload the data.
+    ids_to_load : list of int | None
+        Ids of specific files to load.
+    target_name : str | list | None
+        Load specific description column as target. If not given, take saved
+        target name.
+    n_jobs : int
+        Number of jobs to be used to read files in parallel.
+
+    Returns
+    -------
+    concat_dataset : BaseConcatDataset
+    """
+    # Make sure we always work with a pathlib.Path
+    path = Path(path)
+
+    # if we encounter a dataset that was saved in 'the old way', call the
+    # corresponding 'old' loading function
+    if _is_outdated_saved(path):
+        warnings.warn(
+            "The way your dataset was saved is deprecated by now. "
+            "Please save it again using dataset.save().",
+            UserWarning,
+        )
+        return _outdated_load_concat_dataset(
+            path=path, preload=preload, ids_to_load=ids_to_load, target_name=target_name
+        )
+
+    # else we have a dataset saved in the new way with subdirectories in path
+    # for every dataset with description.json and -epo.fif or -raw.fif,
+    # target_name.json, raw_preproc_kwargs.json, window_kwargs.json,
+    # window_preproc_kwargs.json
+    if ids_to_load is None:
+        ids_to_load = [p.name for p in path.iterdir()]
+        ids_to_load = sorted(ids_to_load, key=lambda i: int(i))
+    ids_to_load = [str(i) for i in ids_to_load]
+    first_raw_fif_path = path / ids_to_load[0] / f"{ids_to_load[0]}-raw.fif"
+    is_raw = first_raw_fif_path.exists()
+    metadata_path = path / ids_to_load[0] / "metadata_df.pkl"
+    has_stored_windows = metadata_path.exists()
+
+    # Parallelization of mne.read_epochs with preload=False fails with
+    # 'TypeError: cannot pickle '_io.BufferedReader' object'.
+    # So ignore n_jobs in that case and load with a single job.
+    if not is_raw and n_jobs != 1:
+        warnings.warn(
+            "Parallelized reading with `preload=False` is not supported for "
+            "windowed data. Will use `n_jobs=1`.",
+            UserWarning,
+        )
+        n_jobs = 1
+    datasets = Parallel(n_jobs)(
+        delayed(_load_parallel)(path, i, preload, is_raw, has_stored_windows)
+        for i in ids_to_load
+    )
+    return BaseConcatDataset(datasets)
+
+
+def _load_parallel(path, i, preload, is_raw, has_stored_windows):
+    sub_dir = path / i
+    file_name_patterns = ["{}-raw.fif", "{}-epo.fif"]
+    if all([(sub_dir / p.format(i)).exists() for p in file_name_patterns]):
+        raise FileExistsError("Found -raw.fif and -epo.fif in directory.")
+
+    fif_name_pattern = file_name_patterns[0] if is_raw else file_name_patterns[1]
+    fif_file_name = fif_name_pattern.format(i)
+    fif_file_path = sub_dir / fif_file_name
+
+    signals = _load_signals(fif_file_path, preload, is_raw)
+
+    description_file_path = sub_dir / "description.json"
+    description = pd.read_json(description_file_path, typ="series", convert_dates=False)
+
+    # if 'timestamp' in description.index:
+    #     timestamp_numeric = pd.to_numeric(description['timestamp'])
+    #     description['timestamp'] = pd.to_datetime(timestamp_numeric, unit='s')
+
+    target_file_path = sub_dir / "target_name.json"
+    target_name = None
+    if target_file_path.exists():
+        target_name = json.load(open(target_file_path, "r"))["target_name"]
+
+    if is_raw and (not has_stored_windows):
+        dataset = RawDataset(signals, description, target_name)
+    else:
+        window_kwargs = _load_kwargs_json("window_kwargs", sub_dir)
+        windows_ds_kwargs = [
+            kwargs[1] for kwargs in window_kwargs if kwargs[0] == "WindowsDataset"
+        ]
+        windows_ds_kwargs = windows_ds_kwargs[0] if len(windows_ds_kwargs) == 1 else {}
+        if is_raw:
+            metadata = pd.read_pickle(path / i / "metadata_df.pkl")
+            dataset = EEGWindowsDataset(
+                signals,
+                metadata=metadata,
+                description=description,
+                targets_from=windows_ds_kwargs.get("targets_from", "metadata"),
+                last_target_only=windows_ds_kwargs.get("last_target_only", True),
+            )
+        else:
+            # MNE epochs dataset
+            dataset = WindowsDataset(
+                signals,
+                description,
+                targets_from=windows_ds_kwargs.get("targets_from", "metadata"),
+                last_target_only=windows_ds_kwargs.get("last_target_only", True),
+            )
+        setattr(dataset, "window_kwargs", window_kwargs)
+    for kwargs_name in ["raw_preproc_kwargs", "window_preproc_kwargs"]:
+        kwargs = _load_kwargs_json(kwargs_name, sub_dir)
+        setattr(dataset, kwargs_name, kwargs)
+    return dataset
+
+
+def _load_kwargs_json(kwargs_name, sub_dir):
+    kwargs_file_name = ".".join([kwargs_name, "json"])
+    kwargs_file_path = os.path.join(sub_dir, kwargs_file_name)
+    if os.path.exists(kwargs_file_path):
+        kwargs = json.load(open(kwargs_file_path, "r"))
+        return kwargs
+
+
+def _is_outdated_saved(path):
+    """Data was saved in the old way if 'description.json', '0-raw.fif' or
+    '0-epo.fif' exist directly in path (no subdirectories), OR if the number
+    of 'fif' files does not match the number of 'description.json' files.
+    """
+    description_files = glob(os.path.join(path, "**/description.json"))
+    fif_files = glob(os.path.join(path, "**/*-raw.fif")) + glob(
+        os.path.join(path, "**/*-epo.fif")
+    )
+    multiple = len(description_files) != len(fif_files)
+    kwargs_in_path = any(
+        [
+            os.path.exists(os.path.join(path, kwarg_name))
+            for kwarg_name in [
+                "raw_preproc_kwargs",
+                "window_kwargs",
+                "window_preproc_kwargs",
+            ]
+        ]
+    )
+    return (
+        os.path.exists(os.path.join(path, "description.json"))
+        or os.path.exists(os.path.join(path, "0-raw.fif"))
+        or os.path.exists(os.path.join(path, "0-epo.fif"))
+        or multiple
+        or kwargs_in_path
+    )
+
+
+def _check_save_dir_empty(save_dir):
+    """Make sure a BaseConcatDataset can be saved under a given directory.
+
+    Parameters
+    ----------
+    save_dir : str
+        Directory under which a `BaseConcatDataset` will be saved.
+
+    Raises
+    ------
+    FileExistsError
+        If ``save_dir`` is not a valid directory for saving.
+    """
+    sub_dirs = [
+        os.path.basename(s).isdigit() for s in glob(os.path.join(save_dir, "*"))
+    ]
+    if any(sub_dirs):
+        raise FileExistsError(
+            f"Directory {save_dir} already contains subdirectories. Please "
+            "select a different directory, set overwrite=True, or resolve "
+            "manually."
+        )
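For orientation, a minimal round-trip through this module's public entry points could look like the sketch below. It is illustrative only: the MOABBDataset constructor arguments and the target directory are assumptions, and load_concat_dataset is assumed to be re-exported from braindecode.datautil as in previous releases.

from braindecode.datasets import MOABBDataset
from braindecode.datautil import load_concat_dataset

# Hypothetical source dataset; any BaseConcatDataset saves the same way.
dataset = MOABBDataset(dataset_name="BNCI2014_001", subject_ids=[1])

# dataset.save() writes one numbered subdirectory per recording, each holding
# a -raw.fif or -epo.fif file, description.json, and optional *_kwargs.json.
dataset.save(path="./saved_ds", overwrite=True)

# Reload lazily; n_jobs > 1 reads the per-recording subdirectories in
# parallel (forced back to 1 for windowed data, as warned in the code above).
loaded = load_concat_dataset("./saved_ds", preload=False, n_jobs=2)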
braindecode/datautil/util.py
@@ -0,0 +1,154 @@
+# Authors: Robin Schirrmeister <robintibor@gmail.com>
+#
+# License: BSD (3-clause)
+
+import logging
+from typing import Any, Literal
+
+import mne
+import numpy as np
+from skorch.helper import SliceDataset
+from skorch.utils import is_dataset
+
+from braindecode.datasets.base import BaseConcatDataset, WindowsDataset
+from braindecode.models.util import SigArgName
+
+log = logging.getLogger(__name__)
+
+
+def ms_to_samples(ms, fs):
+    """
+    Convert milliseconds to number of samples.
+
+    Parameters
+    ----------
+    ms : number
+        Milliseconds
+    fs : number
+        Sampling rate
+
+    Returns
+    -------
+    n_samples : float
+        Number of samples
+    """
+    return ms * fs / 1000.0
+
+
+def samples_to_ms(n_samples, fs):
+    """
+    Convert number of samples to milliseconds.
+
+    Parameters
+    ----------
+    n_samples : number
+        Number of samples
+    fs : number
+        Sampling rate
+
+    Returns
+    -------
+    milliseconds : float
+    """
+    return n_samples * 1000.0 / fs
+
+
+def _get_n_outputs(y, classes, mode):
+    if mode == "classification":
+        classes_y = np.unique(y)
+        if classes is not None:
+            assert set(classes_y) <= set(classes)
+        else:
+            classes = classes_y
+        return len(classes)
+    elif mode == "regression":
+        if y is None:
+            return None
+        if y.ndim == 1:
+            return 1
+        else:
+            return y.shape[-1]
+    else:
+        raise ValueError(f"Unknown mode {mode}")
+
+
+def infer_signal_properties(
+    X,
+    y=None,
+    mode: Literal["classification", "regression"] = "classification",
+    classes: list | None = None,
+) -> dict[SigArgName, Any]:
+    """Infer signal properties from the data.
+
+    The extracted signal properties are:
+
+    + n_chans: number of channels
+    + n_times: number of time points
+    + n_outputs: number of outputs
+    + chs_info: channel information
+    + sfreq: sampling frequency
+
+    The returned dictionary can serve as kwargs for model initialization.
+
+    Depending on the type of input passed, not all properties can be inferred.
+
+    Parameters
+    ----------
+    X : array-like or mne.BaseEpochs or Dataset
+        Input data
+    y : array-like or None
+        Targets
+    mode : "classification" or "regression"
+        Mode of the task
+    classes : list or None
+        List of classes for classification
+
+    Returns
+    -------
+    signal_kwargs : dict
+        Dictionary with signal properties. Can serve as kwargs for model
+        initialization.
+    """
+    signal_kwargs: dict[SigArgName, Any] = {}
+    # Using shape to work both with torch.tensor and numpy.array:
+    if (
+        isinstance(X, mne.BaseEpochs)
+        or (hasattr(X, "shape") and len(X.shape) >= 2)
+        or isinstance(X, SliceDataset)
+    ):
+        if y is None:
+            raise ValueError("y must be specified if X is array-like.")
+        signal_kwargs["n_outputs"] = _get_n_outputs(y, classes, mode)
+        if isinstance(X, mne.BaseEpochs):
+            log.info("Using mne.Epochs to find signal-related parameters.")
+            signal_kwargs["n_times"] = len(X.times)
+            signal_kwargs["sfreq"] = X.info["sfreq"]
+            signal_kwargs["chs_info"] = X.info["chs"]
+        elif isinstance(X, SliceDataset):
+            log.info("Using SliceDataset to find signal-related parameters.")
+            Xshape = X[0].shape
+            signal_kwargs["n_times"] = Xshape[-1]
+            signal_kwargs["n_chans"] = Xshape[-2]
+        else:
+            log.info("Using array-like to find signal-related parameters.")
+            signal_kwargs["n_times"] = X.shape[-1]
+            signal_kwargs["n_chans"] = X.shape[-2]
+    elif is_dataset(X):
+        log.info(f"Using Dataset {X!r} to find signal-related parameters.")
+        X0 = X[0][0]
+        Xshape = X0.shape
+        signal_kwargs["n_times"] = Xshape[-1]
+        signal_kwargs["n_chans"] = Xshape[-2]
+        if isinstance(X, BaseConcatDataset) and all(
+            ds.targets_from == "metadata" for ds in X.datasets
+        ):
+            y_target = X.get_metadata().target
+            signal_kwargs["n_outputs"] = _get_n_outputs(y_target, classes, mode)
+        elif isinstance(X, WindowsDataset) and X.targets_from == "metadata":
+            y_target = X.windows.metadata.target
+            signal_kwargs["n_outputs"] = _get_n_outputs(y_target, classes, mode)
+    else:
+        log.warning(
+            f"Can only infer signal shape of array-like and Datasets, got {type(X)!r}."
+        )
+    return signal_kwargs
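As a quick sanity check of the two conversion helpers, a sketch assuming they are importable from braindecode.datautil.util, the module this hunk appears to add:

from braindecode.datautil.util import ms_to_samples, samples_to_ms

fs = 250  # sampling rate in Hz
assert ms_to_samples(500, fs) == 125.0  # 500 ms * 250 Hz / 1000 = 125 samples
assert samples_to_ms(125, fs) == 500.0  # 125 samples * 1000 / 250 Hz = 500 ms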
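Similarly, infer_signal_properties on array-like input takes the hasattr(X, "shape") branch shown above; the shapes and labels below are made up for illustration, under the same import-path assumption:

import numpy as np

from braindecode.datautil.util import infer_signal_properties

# Hypothetical input: 10 trials of 22-channel EEG, 1000 time points each.
X = np.random.randn(10, 22, 1000).astype("float32")
y = np.array([0, 1] * 5)

props = infer_signal_properties(X, y=y, mode="classification")
# Expected: {"n_outputs": 2, "n_times": 1000, "n_chans": 22} -- enough to
# instantiate most braindecode models via **props.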