eegdash 0.0.8__py3-none-any.whl → 0.1.0__py3-none-any.whl

This diff shows the content of publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.


@@ -0,0 +1,87 @@
+"""
+Convenience functions for storing and loading of features datasets.
+
+see also: https://github.com/braindecode/braindecode//blob/master/braindecode/datautil/serialization.py#L165-L229
+"""
+
+import json
+from pathlib import Path
+
+import pandas as pd
+from joblib import Parallel, delayed
+from mne.io import read_info
+
+from braindecode.datautil.serialization import _load_kwargs_json
+
+from .datasets import (
+    FeaturesConcatDataset,
+    FeaturesDataset,
+)
+
+
+def load_features_concat_dataset(path, ids_to_load=None, n_jobs=1):
+    """Load a stored FeaturesConcatDataset of FeaturesDatasets from files.
+
+    Parameters
+    ----------
+    path: str | pathlib.Path
+        Path to the directory of the .fif / -epo.fif and .json files.
+    ids_to_load: list of int | None
+        Ids of specific files to load.
+    n_jobs: int
+        Number of jobs to be used to read files in parallel.
+
+    Returns
+    -------
+    concat_dataset: FeaturesConcatDataset of FeaturesDatasets
+    """
+    # Make sure we always work with a pathlib.Path
+    path = Path(path)
+
+    # else we have a dataset saved in the new way with subdirectories in path
+    # for every dataset with description.json and -feat.parquet,
+    # target_name.json, raw_preproc_kwargs.json, window_kwargs.json,
+    # window_preproc_kwargs.json, features_kwargs.json
+    if ids_to_load is None:
+        ids_to_load = [p.name for p in path.iterdir()]
+        ids_to_load = sorted(ids_to_load, key=lambda i: int(i))
+    ids_to_load = [str(i) for i in ids_to_load]
+
+    datasets = Parallel(n_jobs)(delayed(_load_parallel)(path, i) for i in ids_to_load)
+    return FeaturesConcatDataset(datasets)
+
+
+def _load_parallel(path, i):
+    sub_dir = path / i
+
+    parquet_name_pattern = "{}-feat.parquet"
+    parquet_file_name = parquet_name_pattern.format(i)
+    parquet_file_path = sub_dir / parquet_file_name
+
+    features = pd.read_parquet(parquet_file_path)
+
+    description_file_path = sub_dir / "description.json"
+    description = pd.read_json(description_file_path, typ="series")
+
+    raw_info_file_path = sub_dir / "raw-info.fif"
+    raw_info = None
+    if raw_info_file_path.exists():
+        raw_info = read_info(raw_info_file_path)
+
+    raw_preproc_kwargs = _load_kwargs_json("raw_preproc_kwargs", sub_dir)
+    window_kwargs = _load_kwargs_json("window_kwargs", sub_dir)
+    window_preproc_kwargs = _load_kwargs_json("window_preproc_kwargs", sub_dir)
+    features_kwargs = _load_kwargs_json("features_kwargs", sub_dir)
+    metadata = pd.read_pickle(path / i / "metadata_df.pkl")
+
+    dataset = FeaturesDataset(
+        features,
+        metadata=metadata,
+        description=description,
+        raw_info=raw_info,
+        raw_preproc_kwargs=raw_preproc_kwargs,
+        window_kwargs=window_kwargs,
+        window_preproc_kwargs=window_preproc_kwargs,
+        features_kwargs=features_kwargs,
+    )
+    return dataset
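
A rough usage sketch for the loader above: it expects a directory with one numbered subdirectory per dataset, each holding a {i}-feat.parquet feature table, description.json, metadata_df.pkl, an optional raw-info.fif, and the *_kwargs.json files. The import path eegdash.features is an assumption and is not shown in this diff.

    from eegdash.features import load_features_concat_dataset  # assumed import path

    # Load only sub-datasets "0" and "1" from ./features_dir, reading each
    # subdirectory's parquet/json files with 4 parallel joblib workers.
    concat_ds = load_features_concat_dataset(
        "./features_dir", ids_to_load=[0, 1], n_jobs=4
    )
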
@@ -0,0 +1,116 @@
+import copy
+from collections.abc import Callable
+from typing import Dict, List
+
+import numpy as np
+import pandas as pd
+from joblib import Parallel, delayed
+from torch.utils.data import DataLoader
+from tqdm import tqdm
+
+from braindecode.datasets.base import (
+    BaseConcatDataset,
+    EEGWindowsDataset,
+    WindowsDataset,
+)
+
+from .datasets import FeaturesConcatDataset, FeaturesDataset
+from .extractors import FeatureExtractor
+
+
+def _extract_features_from_windowsdataset(
+    win_ds: EEGWindowsDataset | WindowsDataset,
+    feature_extractor: FeatureExtractor,
+    batch_size: int = 512,
+):
+    metadata = win_ds.metadata
+    if not win_ds.targets_from == "metadata":
+        metadata = copy.deepcopy(metadata)
+        metadata["orig_index"] = metadata.index
+        metadata.set_index(
+            ["i_window_in_trial", "i_start_in_trial", "i_stop_in_trial"],
+            drop=False,
+            inplace=True,
+        )
+    win_dl = DataLoader(win_ds, batch_size=batch_size, shuffle=False, drop_last=False)
+    features_dict = dict()
+    ch_names = win_ds.raw.ch_names
+    for X, y, crop_inds in win_dl:
+        X = X.numpy()
+        if hasattr(y, "tolist"):
+            y = y.tolist()
+        win_dict = dict()
+        win_dict.update(
+            feature_extractor(X, _batch_size=X.shape[0], _ch_names=ch_names)
+        )
+        if not win_ds.targets_from == "metadata":
+            metadata.loc[crop_inds, "target"] = y
+        for k, v in win_dict.items():
+            if k not in features_dict:
+                features_dict[k] = []
+            features_dict[k].extend(v)
+    features_df = pd.DataFrame(features_dict)
+    if not win_ds.targets_from == "metadata":
+        metadata.set_index("orig_index", drop=False, inplace=True)
+        metadata.reset_index(drop=True, inplace=True)
+        metadata.drop("orig_index", axis=1, inplace=True)
+
+    # FUTURE: truly support WindowsDataset objects
+    return FeaturesDataset(
+        features_df,
+        metadata=metadata,
+        description=win_ds.description,
+        raw_info=win_ds.raw.info,
+        raw_preproc_kwargs=win_ds.raw_preproc_kwargs,
+        window_kwargs=win_ds.window_kwargs,
+        features_kwargs=feature_extractor.features_kwargs,
+    )
+
+
+def extract_features(
+    concat_dataset: BaseConcatDataset,
+    features: FeatureExtractor | Dict[str, Callable] | List[Callable],
+    *,
+    batch_size: int = 512,
+    n_jobs: int = 1,
+):
+    if isinstance(features, list):
+        features = dict(enumerate(features))
+    if not isinstance(features, FeatureExtractor):
+        features = FeatureExtractor(features)
+    feature_ds_list = list(
+        tqdm(
+            Parallel(n_jobs=n_jobs, return_as="generator")(
+                delayed(_extract_features_from_windowsdataset)(
+                    win_ds, features, batch_size
+                )
+                for win_ds in concat_dataset.datasets
+            ),
+            total=len(concat_dataset.datasets),
+            desc="Extracting features",
+        )
+    )
+    return FeaturesConcatDataset(feature_ds_list)
+
+
+def fit_feature_extractors(
+    concat_dataset: BaseConcatDataset,
+    features: FeatureExtractor | Dict[str, Callable] | List[Callable],
+    batch_size: int = 8192,
+):
+    if isinstance(features, list):
+        features = dict(enumerate(features))
+    if not isinstance(features, FeatureExtractor):
+        features = FeatureExtractor(features)
+    if not features._is_fitable:
+        return features
+    features.clear()
+    concat_dl = DataLoader(
+        concat_dataset, batch_size=batch_size, shuffle=False, drop_last=False
+    )
+    for X, y, _ in tqdm(
+        concat_dl, total=len(concat_dl), desc="Fitting feature extractors"
+    ):
+        features.partial_fit(X.numpy(), y=np.array(y))
+    features.fit()
+    return features
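
Taken together, the two public helpers form a two-pass pipeline: fit_feature_extractors streams the whole windowed dataset once in large batches to fit any stateful extractors (returning the extractor untouched if nothing is fitable), and extract_features then runs the extractor over each windows dataset in parallel and concatenates the results into a FeaturesConcatDataset. A minimal sketch follows, assuming windows_concat_ds is a braindecode BaseConcatDataset of windowed recordings and that the helpers are exported from eegdash.features; the calling convention for individual feature callables is defined by FeatureExtractor in .extractors, which is not part of this diff, so the toy features below are only illustrative.

    import numpy as np
    from eegdash.features import extract_features, fit_feature_extractors  # assumed import path

    # Toy feature callables: each maps a batch of windows with shape
    # (batch, n_channels, n_times) to one value per window.
    features = {
        "rms": lambda X: np.sqrt((X**2).mean(axis=(-2, -1))),
        "ptp": lambda X: X.max(axis=(-2, -1)) - X.min(axis=(-2, -1)),
    }

    # Pass 1: fit stateful extractors over the full dataset
    # (a no-op for plain callables like the ones above).
    extractor = fit_feature_extractors(windows_concat_ds, features, batch_size=8192)

    # Pass 2: compute per-window features dataset-by-dataset, in parallel.
    feature_concat_ds = extract_features(
        windows_concat_ds, extractor, batch_size=512, n_jobs=4
    )
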