eegdash 0.0.8__py3-none-any.whl → 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release. This version of eegdash might be problematic.
- eegdash/__init__.py +4 -1
- eegdash/data_config.py +28 -0
- eegdash/data_utils.py +193 -148
- eegdash/features/__init__.py +25 -0
- eegdash/features/datasets.py +456 -0
- eegdash/features/decorators.py +43 -0
- eegdash/features/extractors.py +210 -0
- eegdash/features/feature_bank/__init__.py +6 -0
- eegdash/features/feature_bank/complexity.py +96 -0
- eegdash/features/feature_bank/connectivity.py +59 -0
- eegdash/features/feature_bank/csp.py +101 -0
- eegdash/features/feature_bank/dimensionality.py +107 -0
- eegdash/features/feature_bank/signal.py +103 -0
- eegdash/features/feature_bank/spectral.py +116 -0
- eegdash/features/feature_bank/utils.py +48 -0
- eegdash/features/serialization.py +87 -0
- eegdash/features/utils.py +116 -0
- eegdash/main.py +250 -145
- {eegdash-0.0.8.dist-info → eegdash-0.1.0.dist-info}/METADATA +26 -56
- eegdash-0.1.0.dist-info/RECORD +23 -0
- {eegdash-0.0.8.dist-info → eegdash-0.1.0.dist-info}/WHEEL +1 -1
- eegdash-0.0.8.dist-info/RECORD +0 -8
- {eegdash-0.0.8.dist-info → eegdash-0.1.0.dist-info}/licenses/LICENSE +0 -0
- {eegdash-0.0.8.dist-info → eegdash-0.1.0.dist-info}/top_level.txt +0 -0
eegdash/features/serialization.py

@@ -0,0 +1,87 @@
+"""
+Convenience functions for storing and loading of features datasets.
+
+see also: https://github.com/braindecode/braindecode//blob/master/braindecode/datautil/serialization.py#L165-L229
+"""
+
+import json
+from pathlib import Path
+
+import pandas as pd
+from joblib import Parallel, delayed
+from mne.io import read_info
+
+from braindecode.datautil.serialization import _load_kwargs_json
+
+from .datasets import (
+    FeaturesConcatDataset,
+    FeaturesDataset,
+)
+
+
+def load_features_concat_dataset(path, ids_to_load=None, n_jobs=1):
+    """Load a stored FeaturesConcatDataset of FeaturesDatasets from files.
+
+    Parameters
+    ----------
+    path: str | pathlib.Path
+        Path to the directory of the .fif / -epo.fif and .json files.
+    ids_to_load: list of int | None
+        Ids of specific files to load.
+    n_jobs: int
+        Number of jobs to be used to read files in parallel.
+
+    Returns
+    -------
+    concat_dataset: FeaturesConcatDataset of FeaturesDatasets
+    """
+    # Make sure we always work with a pathlib.Path
+    path = Path(path)
+
+    # else we have a dataset saved in the new way with subdirectories in path
+    # for every dataset with description.json and -feat.parquet,
+    # target_name.json, raw_preproc_kwargs.json, window_kwargs.json,
+    # window_preproc_kwargs.json, features_kwargs.json
+    if ids_to_load is None:
+        ids_to_load = [p.name for p in path.iterdir()]
+        ids_to_load = sorted(ids_to_load, key=lambda i: int(i))
+    ids_to_load = [str(i) for i in ids_to_load]
+
+    datasets = Parallel(n_jobs)(delayed(_load_parallel)(path, i) for i in ids_to_load)
+    return FeaturesConcatDataset(datasets)
+
+
+def _load_parallel(path, i):
+    sub_dir = path / i
+
+    parquet_name_pattern = "{}-feat.parquet"
+    parquet_file_name = parquet_name_pattern.format(i)
+    parquet_file_path = sub_dir / parquet_file_name
+
+    features = pd.read_parquet(parquet_file_path)
+
+    description_file_path = sub_dir / "description.json"
+    description = pd.read_json(description_file_path, typ="series")
+
+    raw_info_file_path = sub_dir / "raw-info.fif"
+    raw_info = None
+    if raw_info_file_path.exists():
+        raw_info = read_info(raw_info_file_path)
+
+    raw_preproc_kwargs = _load_kwargs_json("raw_preproc_kwargs", sub_dir)
+    window_kwargs = _load_kwargs_json("window_kwargs", sub_dir)
+    window_preproc_kwargs = _load_kwargs_json("window_preproc_kwargs", sub_dir)
+    features_kwargs = _load_kwargs_json("features_kwargs", sub_dir)
+    metadata = pd.read_pickle(path / i / "metadata_df.pkl")
+
+    dataset = FeaturesDataset(
+        features,
+        metadata=metadata,
+        description=description,
+        raw_info=raw_info,
+        raw_preproc_kwargs=raw_preproc_kwargs,
+        window_kwargs=window_kwargs,
+        window_preproc_kwargs=window_preproc_kwargs,
+        features_kwargs=features_kwargs,
+    )
+    return dataset
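For reference, a minimal usage sketch of the new loader (not part of the diff). The directory path and ids below are hypothetical; the sketch assumes the per-dataset layout that _load_parallel reads: one integer-named subdirectory per dataset containing "<i>-feat.parquet", "description.json", "metadata_df.pkl", the *_kwargs.json files, and optionally "raw-info.fif".

    from eegdash.features.serialization import load_features_concat_dataset

    # Hypothetical directory previously written in the layout described above.
    concat_ds = load_features_concat_dataset(
        path="./features_store",
        ids_to_load=[0, 2],  # only load subdirectories "0" and "2"
        n_jobs=2,            # read the per-dataset files in parallel via joblib
    )
    print(len(concat_ds.datasets))  # number of FeaturesDataset objects loaded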
eegdash/features/utils.py

@@ -0,0 +1,116 @@
+import copy
+from collections.abc import Callable
+from typing import Dict, List
+
+import numpy as np
+import pandas as pd
+from joblib import Parallel, delayed
+from torch.utils.data import DataLoader
+from tqdm import tqdm
+
+from braindecode.datasets.base import (
+    BaseConcatDataset,
+    EEGWindowsDataset,
+    WindowsDataset,
+)
+
+from .datasets import FeaturesConcatDataset, FeaturesDataset
+from .extractors import FeatureExtractor
+
+
+def _extract_features_from_windowsdataset(
+    win_ds: EEGWindowsDataset | WindowsDataset,
+    feature_extractor: FeatureExtractor,
+    batch_size: int = 512,
+):
+    metadata = win_ds.metadata
+    if not win_ds.targets_from == "metadata":
+        metadata = copy.deepcopy(metadata)
+        metadata["orig_index"] = metadata.index
+        metadata.set_index(
+            ["i_window_in_trial", "i_start_in_trial", "i_stop_in_trial"],
+            drop=False,
+            inplace=True,
+        )
+    win_dl = DataLoader(win_ds, batch_size=batch_size, shuffle=False, drop_last=False)
+    features_dict = dict()
+    ch_names = win_ds.raw.ch_names
+    for X, y, crop_inds in win_dl:
+        X = X.numpy()
+        if hasattr(y, "tolist"):
+            y = y.tolist()
+        win_dict = dict()
+        win_dict.update(
+            feature_extractor(X, _batch_size=X.shape[0], _ch_names=ch_names)
+        )
+        if not win_ds.targets_from == "metadata":
+            metadata.loc[crop_inds, "target"] = y
+        for k, v in win_dict.items():
+            if k not in features_dict:
+                features_dict[k] = []
+            features_dict[k].extend(v)
+    features_df = pd.DataFrame(features_dict)
+    if not win_ds.targets_from == "metadata":
+        metadata.set_index("orig_index", drop=False, inplace=True)
+        metadata.reset_index(drop=True, inplace=True)
+        metadata.drop("orig_index", axis=1, inplace=True)
+
+    # FUTURE: truly support WindowsDataset objects
+    return FeaturesDataset(
+        features_df,
+        metadata=metadata,
+        description=win_ds.description,
+        raw_info=win_ds.raw.info,
+        raw_preproc_kwargs=win_ds.raw_preproc_kwargs,
+        window_kwargs=win_ds.window_kwargs,
+        features_kwargs=feature_extractor.features_kwargs,
+    )
+
+
+def extract_features(
+    concat_dataset: BaseConcatDataset,
+    features: FeatureExtractor | Dict[str, Callable] | List[Callable],
+    *,
+    batch_size: int = 512,
+    n_jobs: int = 1,
+):
+    if isinstance(features, list):
+        features = dict(enumerate(features))
+    if not isinstance(features, FeatureExtractor):
+        features = FeatureExtractor(features)
+    feature_ds_list = list(
+        tqdm(
+            Parallel(n_jobs=n_jobs, return_as="generator")(
+                delayed(_extract_features_from_windowsdataset)(
+                    win_ds, features, batch_size
+                )
+                for win_ds in concat_dataset.datasets
+            ),
+            total=len(concat_dataset.datasets),
+            desc="Extracting features",
+        )
+    )
+    return FeaturesConcatDataset(feature_ds_list)
+
+
+def fit_feature_extractors(
+    concat_dataset: BaseConcatDataset,
+    features: FeatureExtractor | Dict[str, Callable] | List[Callable],
+    batch_size: int = 8192,
+):
+    if isinstance(features, list):
+        features = dict(enumerate(features))
+    if not isinstance(features, FeatureExtractor):
+        features = FeatureExtractor(features)
+    if not features._is_fitable:
+        return features
+    features.clear()
+    concat_dl = DataLoader(
+        concat_dataset, batch_size=batch_size, shuffle=False, drop_last=False
+    )
+    for X, y, _ in tqdm(
+        concat_dl, total=len(concat_dl), desc="Fitting feature extractors"
+    ):
+        features.partial_fit(X.numpy(), y=np.array(y))
+    features.fit()
+    return features