eegdash 0.0.8__py3-none-any.whl → 0.0.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of eegdash might be problematic.
- eegdash/data_config.py +28 -0
- eegdash/data_utils.py +55 -56
- eegdash/features/__init__.py +25 -0
- eegdash/features/datasets.py +453 -0
- eegdash/features/decorators.py +43 -0
- eegdash/features/extractors.py +209 -0
- eegdash/features/feature_bank/__init__.py +6 -0
- eegdash/features/feature_bank/complexity.py +97 -0
- eegdash/features/feature_bank/connectivity.py +99 -0
- eegdash/features/feature_bank/csp.py +102 -0
- eegdash/features/feature_bank/dimensionality.py +108 -0
- eegdash/features/feature_bank/signal.py +103 -0
- eegdash/features/feature_bank/spectral.py +134 -0
- eegdash/features/serialization.py +87 -0
- eegdash/features/utils.py +114 -0
- eegdash/main.py +98 -50
- {eegdash-0.0.8.dist-info → eegdash-0.0.9.dist-info}/METADATA +13 -47
- eegdash-0.0.9.dist-info/RECORD +22 -0
- {eegdash-0.0.8.dist-info → eegdash-0.0.9.dist-info}/WHEEL +1 -1
- eegdash-0.0.8.dist-info/RECORD +0 -8
- {eegdash-0.0.8.dist-info → eegdash-0.0.9.dist-info}/licenses/LICENSE +0 -0
- {eegdash-0.0.8.dist-info → eegdash-0.0.9.dist-info}/top_level.txt +0 -0
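The headline changes are a new eegdash.features subpackage (feature extractors, a feature bank, and parquet-based serialization) and a reworked eegdash/main.py in which EEGDashDataset gains a cache_dir argument in place of the old CACHE_DIR class constant (see the main.py diff below). A minimal usage sketch, assuming EEGDashDataset is importable from the package root and using a hypothetical OpenNeuro dataset id:

from eegdash import EEGDashDataset  # assumption: re-exported at package level; otherwise import from eegdash.main

ds = EEGDashDataset(
    query={"dataset": "ds002718"},   # hypothetical dataset id; 'dataset' and 'data_name' are the query fields accepted by EEGDash.exist()
    cache_dir=".eegdash_cache",      # new in 0.0.9; replaces the hard-coded CACHE_DIR class attribute
)
print(len(ds.datasets))              # one dataset object per matching recording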
eegdash/features/feature_bank/signal.py
ADDED

@@ -0,0 +1,103 @@
+import numbers
+import numpy as np
+from scipy import stats
+
+from ..decorators import univariate_feature
+
+
+__all__ = [
+    "signal_mean",
+    "signal_variance",
+    "signal_skewness",
+    "signal_kurtosis",
+    "signal_std",
+    "signal_root_mean_square",
+    "signal_peak_to_peak",
+    "signal_quantile",
+    "signal_zero_crossings",
+    "signal_line_length",
+    "signal_hjorth_activity",
+    "signal_hjorth_mobility",
+    "signal_hjorth_complexity",
+    "signal_decorrelation_time",
+]
+
+
+@univariate_feature
+def signal_mean(x):
+    return x.mean(axis=-1)
+
+
+@univariate_feature
+def signal_variance(x, **kwargs):
+    return x.var(axis=-1, **kwargs)
+
+
+@univariate_feature
+def signal_std(x, **kwargs):
+    return x.std(axis=-1, **kwargs)
+
+
+@univariate_feature
+def signal_skewness(x, **kwargs):
+    return stats.skew(x, axis=x.ndim - 1, **kwargs)
+
+
+@univariate_feature
+def signal_kurtosis(x, **kwargs):
+    return stats.kurtosis(x, axis=x.ndim - 1, **kwargs)
+
+
+@univariate_feature
+def signal_root_mean_square(x):
+    return np.sqrt(np.power(x, 2).mean(axis=-1))
+
+
+@univariate_feature
+def signal_peak_to_peak(x, **kwargs):
+    return np.ptp(x, axis=-1, **kwargs)
+
+
+@univariate_feature
+def signal_quantile(x, q: numbers.Number = 0.5, **kwargs):
+    return np.quantile(x, q=q, axis=-1, **kwargs)
+
+
+@univariate_feature
+def signal_line_length(x):
+    return np.abs(np.diff(x, axis=-1)).mean(axis=-1)
+
+
+@univariate_feature
+def signal_zero_crossings(x, threshold=1e-15):
+    zero_ind = np.logical_and(x > -threshold, x < threshold)
+    zero_cross = np.diff(zero_ind, axis=-1).astype(int).sum(axis=-1)
+    y = x.copy()
+    y[zero_ind] = 0
+    zero_cross += np.sum(np.signbit(y[..., :-1]) != np.signbit(y[..., 1:]), axis=-1)
+    return zero_cross
+
+
+@univariate_feature
+def signal_hjorth_mobility(x):
+    return np.diff(x, axis=-1).std(axis=-1) / x.std(axis=-1)
+
+
+@univariate_feature
+def signal_hjorth_complexity(x):
+    return np.diff(x, 2, axis=-1).std(axis=-1) / x.std(axis=-1)
+
+
+@univariate_feature
+def signal_decorrelation_time(x, fs=1):
+    f = np.fft.fft(x - x.mean(axis=-1, keepdims=True), axis=-1)
+    ac = np.fft.ifft(f.real**2 + f.imag**2, axis=-1)[..., : x.shape[-1] // 2]
+    dct = np.empty(x.shape[:-1])
+    for i in np.ndindex(x.shape[:-1]):
+        dct[i] = np.searchsorted(ac[i] <= 0, True)
+    return dct / fs
+
+
+# ================================= Aliases =================================
+
+signal_hjorth_activity = signal_variance
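The feature functions above reduce the last (time) axis, so each returns one value per channel for a (n_channels, n_times) window. A plain-NumPy sketch of a few of the definitions, bypassing the univariate_feature decorator (whose wrapping behaviour lives in eegdash/features/decorators.py and is not shown in this diff):

import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal((32, 1000))          # one window, (n_channels, n_times)

activity = x.var(axis=-1)                                           # signal_hjorth_activity (= variance)
mobility = np.diff(x, axis=-1).std(axis=-1) / x.std(axis=-1)        # signal_hjorth_mobility
complexity = np.diff(x, 2, axis=-1).std(axis=-1) / x.std(axis=-1)   # signal_hjorth_complexity, as defined above
line_length = np.abs(np.diff(x, axis=-1)).mean(axis=-1)             # signal_line_length

print(activity.shape, mobility.shape)        # one value per channel: (32,) (32,)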
eegdash/features/feature_bank/spectral.py
ADDED

@@ -0,0 +1,134 @@
+import numpy as np
+import numba as nb
+from scipy.signal import welch
+
+from ..extractors import FeatureExtractor
+from ..decorators import FeaturePredecessor, univariate_feature
+
+
+__all__ = [
+    "SpectralFeatureExtractor",
+    "NormalizedSpectralFeatureExtractor",
+    "DBSpectralFeatureExtractor",
+    "spectral_root_total_power",
+    "spectral_moment",
+    "spectral_entropy",
+    "spectral_edge",
+    "spectral_slope",
+    "spectral_bands_power",
+    "spectral_hjorth_activity",
+    "spectral_hjorth_mobility",
+    "spectral_hjorth_complexity",
+]
+
+
+class SpectralFeatureExtractor(FeatureExtractor):
+    def preprocess(self, x, **kwargs):
+        f_min = kwargs.pop("f_min") if "f_min" in kwargs else None
+        f_max = kwargs.pop("f_max") if "f_max" in kwargs else None
+        kwargs["axis"] = -1
+        f, p = welch(x, **kwargs)
+        if f_min is not None or f_max is not None:
+            f_min_idx = f > f_min if f_min is not None else True
+            f_max_idx = f < f_max if f_max is not None else True
+            idx = np.logical_and(f_min_idx, f_max_idx)
+            f = f[idx]
+            p = p[..., idx]
+        return f, p
+
+
+@FeaturePredecessor(SpectralFeatureExtractor)
+class NormalizedSpectralFeatureExtractor(FeatureExtractor):
+    def preprocess(self, *x):
+        return (*x[:-1], x[-1] / x[-1].sum(axis=-1, keepdims=True))
+
+
+@FeaturePredecessor(SpectralFeatureExtractor)
+class DBSpectralFeatureExtractor(FeatureExtractor):
+    def preprocess(self, *x, eps=1e-15):
+        return (*x[:-1], 10 * np.log10(x[-1] + eps))
+
+
+@FeaturePredecessor(SpectralFeatureExtractor)
+@univariate_feature
+def spectral_root_total_power(f, p):
+    return np.sqrt(p.sum(axis=-1))
+
+
+@FeaturePredecessor(NormalizedSpectralFeatureExtractor)
+@univariate_feature
+def spectral_moment(f, p):
+    return np.sum(f * p, axis=-1)
+
+
+@FeaturePredecessor(SpectralFeatureExtractor)
+@univariate_feature
+def spectral_hjorth_activity(f, p):
+    return np.sum(p, axis=-1)
+
+
+@FeaturePredecessor(NormalizedSpectralFeatureExtractor)
+@univariate_feature
+def spectral_hjorth_mobility(f, p):
+    return np.sqrt(np.sum(np.power(f, 2) * p, axis=-1))
+
+
+@FeaturePredecessor(NormalizedSpectralFeatureExtractor)
+@univariate_feature
+def spectral_hjorth_complexity(f, p):
+    return np.sqrt(np.sum(np.power(f, 4) * p, axis=-1))
+
+
+@FeaturePredecessor(NormalizedSpectralFeatureExtractor)
+@univariate_feature
+def spectral_entropy(f, p):
+    idx = p > 0
+    plogp = np.zeros_like(p)
+    plogp[idx] = p[idx] * np.log(p[idx])
+    return -np.sum(plogp, axis=-1)
+
+
+@FeaturePredecessor(NormalizedSpectralFeatureExtractor)
+@univariate_feature
+@nb.njit(cache=True, fastmath=True)
+def spectral_edge(f, p, edge=0.9):
+    se = np.empty(p.shape[:-1])
+    for i in np.ndindex(p.shape[:-1]):
+        se[i] = f[np.searchsorted(np.cumsum(p[i]), edge)]
+    return se
+
+
+@FeaturePredecessor(DBSpectralFeatureExtractor)
+@univariate_feature
+def spectral_slope(f, p):
+    log_f = np.vstack((np.log(f), np.ones(f.shape[0]))).T
+    r = np.linalg.lstsq(log_f, p.reshape(-1, p.shape[-1]).T)[0]
+    r = r.reshape(2, *p.shape[:-1])
+    return {"exp": r[0], "int": r[1]}
+
+
+@FeaturePredecessor(
+    SpectralFeatureExtractor,
+    NormalizedSpectralFeatureExtractor,
+    DBSpectralFeatureExtractor,
+)
+@univariate_feature
+def spectral_bands_power(
+    f,
+    p,
+    bands={
+        "delta": (1, 4.5),
+        "theta": (4.5, 8),
+        "alpha": (8, 12),
+        "beta": (12, 30),
+    },
+):
+    bands_power = dict()
+    for k, v in bands.items():
+        assert isinstance(k, str)
+        assert isinstance(v, tuple)
+        assert len(v) == 2
+        mask = np.logical_and(f > v[0], f < v[1])
+        power = p[..., mask].sum(axis=-1)
+        bands_power[k] = power
+    return bands_power
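The pipeline above is a Welch PSD (SpectralFeatureExtractor.preprocess), optionally normalized to a distribution over frequencies (NormalizedSpectralFeatureExtractor), followed by per-band summation in spectral_bands_power. A standalone sketch of the same computation with SciPy, without the extractor classes:

import numpy as np
from scipy.signal import welch

rng = np.random.default_rng(0)
fs = 250
x = rng.standard_normal((32, 10 * fs))        # (n_channels, n_times)

f, p = welch(x, fs=fs, axis=-1)                # what SpectralFeatureExtractor.preprocess wraps
p_rel = p / p.sum(axis=-1, keepdims=True)      # NormalizedSpectralFeatureExtractor step

bands = {"delta": (1, 4.5), "theta": (4.5, 8), "alpha": (8, 12), "beta": (12, 30)}
band_power = {
    name: p_rel[..., np.logical_and(f > lo, f < hi)].sum(axis=-1)
    for name, (lo, hi) in bands.items()
}
print({k: v.shape for k, v in band_power.items()})   # each band: one value per channel, (32,)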
eegdash/features/serialization.py
ADDED

@@ -0,0 +1,87 @@
+"""
+Convenience functions for storing and loading of features datasets.
+
+see also: https://github.com/braindecode/braindecode//blob/master/braindecode/datautil/serialization.py#L165-L229
+"""
+
+import json
+from pathlib import Path
+
+import pandas as pd
+from joblib import Parallel, delayed
+
+from mne.io import read_info
+from braindecode.datautil.serialization import _load_kwargs_json
+
+from .datasets import (
+    FeaturesDataset,
+    FeaturesConcatDataset,
+)
+
+
+def load_features_concat_dataset(path, ids_to_load=None, n_jobs=1):
+    """Load a stored FeaturesConcatDataset of FeaturesDatasets from files.
+
+    Parameters
+    ----------
+    path: str | pathlib.Path
+        Path to the directory of the .fif / -epo.fif and .json files.
+    ids_to_load: list of int | None
+        Ids of specific files to load.
+    n_jobs: int
+        Number of jobs to be used to read files in parallel.
+
+    Returns
+    -------
+    concat_dataset: FeaturesConcatDataset of FeaturesDatasets
+    """
+    # Make sure we always work with a pathlib.Path
+    path = Path(path)
+
+    # else we have a dataset saved in the new way with subdirectories in path
+    # for every dataset with description.json and -feat.parquet,
+    # target_name.json, raw_preproc_kwargs.json, window_kwargs.json,
+    # window_preproc_kwargs.json, features_kwargs.json
+    if ids_to_load is None:
+        ids_to_load = [p.name for p in path.iterdir()]
+        ids_to_load = sorted(ids_to_load, key=lambda i: int(i))
+    ids_to_load = [str(i) for i in ids_to_load]
+
+    datasets = Parallel(n_jobs)(delayed(_load_parallel)(path, i) for i in ids_to_load)
+    return FeaturesConcatDataset(datasets)
+
+
+def _load_parallel(path, i):
+    sub_dir = path / i
+
+    parquet_name_pattern = "{}-feat.parquet"
+    parquet_file_name = parquet_name_pattern.format(i)
+    parquet_file_path = sub_dir / parquet_file_name
+
+    features = pd.read_parquet(parquet_file_path)
+
+    description_file_path = sub_dir / "description.json"
+    description = pd.read_json(description_file_path, typ="series")
+
+    raw_info_file_path = sub_dir / "raw-info.fif"
+    raw_info = None
+    if raw_info_file_path.exists():
+        raw_info = read_info(raw_info_file_path)
+
+    raw_preproc_kwargs = _load_kwargs_json("raw_preproc_kwargs", sub_dir)
+    window_kwargs = _load_kwargs_json("window_kwargs", sub_dir)
+    window_preproc_kwargs = _load_kwargs_json("window_preproc_kwargs", sub_dir)
+    features_kwargs = _load_kwargs_json("features_kwargs", sub_dir)
+    metadata = pd.read_pickle(path / i / "metadata_df.pkl")
+
+    dataset = FeaturesDataset(
+        features,
+        metadata=metadata,
+        description=description,
+        raw_info=raw_info,
+        raw_preproc_kwargs=raw_preproc_kwargs,
+        window_kwargs=window_kwargs,
+        window_preproc_kwargs=window_preproc_kwargs,
+        features_kwargs=features_kwargs,
+    )
+    return dataset
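load_features_concat_dataset expects one numbered subdirectory per recording, each holding <i>-feat.parquet, description.json, metadata_df.pkl and the optional *_kwargs.json / raw-info.fif files read by _load_parallel. A usage sketch, importing from the module added above and assuming such a directory was previously written out (presumably by the saving counterpart in eegdash/features/datasets.py, which is not shown in this diff); "features_out" is a hypothetical path:

from eegdash.features.serialization import load_features_concat_dataset

# features_out/ is assumed to contain subdirectories 0/, 1/, ... as described above
concat_ds = load_features_concat_dataset("features_out", ids_to_load=[0, 1], n_jobs=2)
print(len(concat_ds.datasets))   # number of FeaturesDataset objects loaded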
eegdash/features/utils.py
ADDED

@@ -0,0 +1,114 @@
+from typing import Dict, List
+from collections.abc import Callable
+import copy
+import numpy as np
+import pandas as pd
+from joblib import Parallel, delayed
+from tqdm import tqdm
+from torch.utils.data import DataLoader
+from braindecode.datasets.base import (
+    EEGWindowsDataset,
+    WindowsDataset,
+    BaseConcatDataset,
+)
+
+from .datasets import FeaturesDataset, FeaturesConcatDataset
+from .extractors import FeatureExtractor
+
+
+def _extract_features_from_windowsdataset(
+    win_ds: EEGWindowsDataset | WindowsDataset,
+    feature_extractor: FeatureExtractor,
+    batch_size: int = 512,
+):
+    metadata = win_ds.metadata
+    if not win_ds.targets_from == "metadata":
+        metadata = copy.deepcopy(metadata)
+        metadata["orig_index"] = metadata.index
+        metadata.set_index(
+            ["i_window_in_trial", "i_start_in_trial", "i_stop_in_trial"],
+            drop=False,
+            inplace=True,
+        )
+    win_dl = DataLoader(win_ds, batch_size=batch_size, shuffle=False, drop_last=False)
+    features_dict = dict()
+    ch_names = win_ds.raw.ch_names
+    for X, y, crop_inds in win_dl:
+        X = X.numpy()
+        if hasattr(y, "tolist"):
+            y = y.tolist()
+        win_dict = dict()
+        win_dict.update(
+            feature_extractor(X, _batch_size=X.shape[0], _ch_names=ch_names)
+        )
+        if not win_ds.targets_from == "metadata":
+            metadata.loc[crop_inds, "target"] = y
+        for k, v in win_dict.items():
+            if k not in features_dict:
+                features_dict[k] = []
+            features_dict[k].extend(v)
+    features_df = pd.DataFrame(features_dict)
+    if not win_ds.targets_from == "metadata":
+        metadata.set_index("orig_index", drop=False, inplace=True)
+        metadata.reset_index(drop=True, inplace=True)
+        metadata.drop("orig_index", axis=1, inplace=True)
+
+    # FUTURE: truely support WindowsDataset objects
+    return FeaturesDataset(
+        features_df,
+        metadata=metadata,
+        description=win_ds.description,
+        raw_info=win_ds.raw.info,
+        raw_preproc_kwargs=win_ds.raw_preproc_kwargs,
+        window_kwargs=win_ds.window_kwargs,
+        features_kwargs=feature_extractor.features_kwargs,
+    )
+
+
+def extract_features(
+    concat_dataset: BaseConcatDataset,
+    features: FeatureExtractor | Dict[str, Callable] | List[Callable],
+    *,
+    batch_size: int = 512,
+    n_jobs: int = 1,
+):
+    if isinstance(features, list):
+        features = dict(enumerate(features))
+    if not isinstance(features, FeatureExtractor):
+        features = FeatureExtractor(features)
+    feature_ds_list = list(
+        tqdm(
+            Parallel(n_jobs=n_jobs, return_as="generator")(
+                delayed(_extract_features_from_windowsdataset)(
+                    win_ds, features, batch_size
+                )
+                for win_ds in concat_dataset.datasets
+            ),
+            total=len(concat_dataset.datasets),
+            desc="Extracting features",
+        )
+    )
+    return FeaturesConcatDataset(feature_ds_list)
+
+
+def fit_feature_extractors(
+    concat_dataset: BaseConcatDataset,
+    features: FeatureExtractor | Dict[str, Callable] | List[Callable],
+    batch_size: int = 8192,
+):
+    if isinstance(features, list):
+        features = dict(enumerate(features))
+    if not isinstance(features, FeatureExtractor):
+        features = FeatureExtractor(features)
+    if not features._is_fitable:
+        return features
+    features.clear()
+    concat_dl = DataLoader(
+        concat_dataset, batch_size=batch_size, shuffle=False, drop_last=False
+    )
+    for X, y, _ in tqdm(
+        concat_dl, total=len(concat_dl), desc="Fitting feature extractors"
+    ):
+        features.partial_fit(X.numpy(), y=np.array(y))
+    features.fit()
+    return features
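extract_features walks every windows dataset in a braindecode BaseConcatDataset, batches windows through a DataLoader, and collects one row of features per window into a FeaturesDataset; fit_feature_extractors only does real work for trainable extractors. A usage sketch, assuming windows_ds is an existing braindecode BaseConcatDataset of windowed recordings (placeholder, not built here) and importing the helpers from the modules added in this release (whether they are also re-exported from eegdash.features depends on features/__init__.py, which is not shown):

from eegdash.features.utils import extract_features, fit_feature_extractors
from eegdash.features.feature_bank.signal import (
    signal_mean,
    signal_variance,
    signal_line_length,
)

features = {
    "mean": signal_mean,
    "var": signal_variance,
    "line_length": signal_line_length,
}

# A no-op for plain callables (the resulting FeatureExtractor is not fitable),
# but required before extraction for trainable extractors such as the ones in
# eegdash/features/feature_bank/csp.py, which presumably need fitting.
features = fit_feature_extractors(windows_ds, features)

feat_ds = extract_features(windows_ds, features, batch_size=512, n_jobs=4)
print(len(feat_ds.datasets))   # one FeaturesDataset per input windows dataset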
eegdash/main.py
CHANGED

@@ -1,15 +1,16 @@
-from typing import List
 import pymongo
 from dotenv import load_dotenv
 import os
 from pathlib import Path
 import s3fs
 from joblib import Parallel, delayed
+import json
 import tempfile
 import mne
 import numpy as np
 import xarray as xr
-from .data_utils import
+from .data_utils import EEGBIDSDataset, EEGDashBaseRaw, EEGDashBaseDataset
+from .data_config import config as data_config
 from braindecode.datasets import BaseDataset, BaseConcatDataset
 from collections import defaultdict
 from pymongo import MongoClient, InsertOne, UpdateOne, DeleteOne

@@ -18,6 +19,12 @@ class EEGDash:
     AWS_BUCKET = 's3://openneuro.org'
     def __init__(self,
                  is_public=True):
+        # Load config file
+        # config_path = Path(__file__).parent / 'config.json'
+        # with open(config_path, 'r') as f:
+        #     self.config = json.load(f)
+
+        self.config = data_config
         if is_public:
             DB_CONNECTION_STRING="mongodb+srv://eegdash-user:mdzoMjQcHWTVnKDq@cluster0.vz35p.mongodb.net/?retryWrites=true&w=majority&appName=Cluster0"
         else:

@@ -37,10 +44,9 @@ class EEGDash:
         # convert to list using get_item on each element
         return [result for result in results]

-    def exist(self,
-
-
-    }
+    def exist(self, query:dict):
+        accepted_query_fields = ['data_name', 'dataset']
+        assert all(field in accepted_query_fields for field in query.keys())
         sessions = self.find(query)
         return len(sessions) > 0

@@ -104,66 +110,111 @@ class EEGDash:
         )
         return eeg_xarray

-    def
+    def get_raw_extensions(self, bids_file, bids_dataset: EEGBIDSDataset):
+        bids_file = Path(bids_file)
+        extensions = {
+            '.set': ['.set', '.fdt'],  # eeglab
+            '.edf': ['.edf'],  # european
+            '.vhdr': ['.eeg', '.vhdr', '.vmrk', '.dat', '.raw'],  # brainvision
+            '.bdf': ['.bdf'],  # biosemi
+        }
+        return [str(bids_dataset.get_relative_bidspath(bids_file.with_suffix(suffix))) for suffix in extensions[bids_file.suffix] if bids_file.with_suffix(suffix).exists()]
+
+    def load_eeg_attrs_from_bids_file(self, bids_dataset: EEGBIDSDataset, bids_file):
         '''
         bids_file must be a file of the bids_dataset
         '''
         if bids_file not in bids_dataset.files:
             raise ValueError(f'{bids_file} not in {bids_dataset.dataset}')
+
+        # Initialize attrs with None values for all expected fields
+        attrs = {field: None for field in self.config['attributes'].keys()}
+
         f = os.path.basename(bids_file)
         dsnumber = bids_dataset.dataset
         # extract openneuro path by finding the first occurrence of the dataset name in the filename and remove the path before that
         openneuro_path = dsnumber + bids_file.split(dsnumber)[1]
-
+        # Update with actual values where available
+        try:
+            participants_tsv = bids_dataset.subject_participant_tsv(bids_file)
+        except Exception as e:
+            print(f"Error getting participants_tsv: {str(e)}")
+            participants_tsv = None
+
+        try:
+            eeg_json = bids_dataset.eeg_json(bids_file)
+        except Exception as e:
+            print(f"Error getting eeg_json: {str(e)}")
+            eeg_json = None
+
+        bids_dependencies_files = self.config['bids_dependencies_files']
         bidsdependencies = []
         for extension in bids_dependencies_files:
-
-
-
-
-
-
-
-
-
-
-
-'
-'
-'
-'
-'
-'
-'
-'
-'
-'
+            try:
+                dep_path = bids_dataset.get_bids_metadata_files(bids_file, extension)
+                dep_path = [str(bids_dataset.get_relative_bidspath(dep)) for dep in dep_path]
+                bidsdependencies.extend(dep_path)
+            except Exception as e:
+                pass
+
+        bidsdependencies.extend(self.get_raw_extensions(bids_file, bids_dataset))
+
+        # Define field extraction functions with error handling
+        field_extractors = {
+            'data_name': lambda: f'{bids_dataset.dataset}_{f}',
+            'dataset': lambda: bids_dataset.dataset,
+            'bidspath': lambda: openneuro_path,
+            'subject': lambda: bids_dataset.get_bids_file_attribute('subject', bids_file),
+            'task': lambda: bids_dataset.get_bids_file_attribute('task', bids_file),
+            'session': lambda: bids_dataset.get_bids_file_attribute('session', bids_file),
+            'run': lambda: bids_dataset.get_bids_file_attribute('run', bids_file),
+            'modality': lambda: bids_dataset.get_bids_file_attribute('modality', bids_file),
+            'sampling_frequency': lambda: bids_dataset.get_bids_file_attribute('sfreq', bids_file),
+            'nchans': lambda: bids_dataset.get_bids_file_attribute('nchans', bids_file),
+            'ntimes': lambda: bids_dataset.get_bids_file_attribute('ntimes', bids_file),
+            'participant_tsv': lambda: participants_tsv,
+            'eeg_json': lambda: eeg_json,
+            'bidsdependencies': lambda: bidsdependencies,
         }
+
+        # Dynamically populate attrs with error handling
+        for field, extractor in field_extractors.items():
+            try:
+                attrs[field] = extractor()
+            except Exception as e:
+                print(f"Error extracting {field}: {str(e)}")
+                attrs[field] = None

         return attrs

-    def add_bids_dataset(self, dataset, data_dir,
+    def add_bids_dataset(self, dataset, data_dir, overwrite=True):
         '''
         Create new records for the dataset in the MongoDB database if not found
         '''
         if self.is_public:
             raise ValueError('This operation is not allowed for public users')

-
-
-
-
-
+        if not overwrite and self.exist({'dataset': dataset}):
+            print(f'Dataset {dataset} already exists in the database')
+            return
+        try:
+            bids_dataset = EEGBIDSDataset(
+                data_dir=data_dir,
+                dataset=dataset,
+            )
+        except Exception as e:
+            print(f'Error creating bids dataset {dataset}: {str(e)}')
+            raise e
         requests = []
         for bids_file in bids_dataset.get_files():
             try:
                 data_id = f"{dataset}_{os.path.basename(bids_file)}"

-                if self.exist(data_name
+                if self.exist({'data_name':data_id}):
                     if overwrite:
                         eeg_attrs = self.load_eeg_attrs_from_bids_file(bids_dataset, bids_file)
-                        requests.append(
+                        requests.append(self.update_request(eeg_attrs))
                     else:
                         eeg_attrs = self.load_eeg_attrs_from_bids_file(bids_dataset, bids_file)
                         requests.append(self.add_request(eeg_attrs))

@@ -224,17 +275,22 @@ class EEGDash:
     def remove_field_from_db(self, field):
         self.__collection.update_many({}, {'$unset': {field: 1}})

+    @property
+    def collection(self):
+        return self.__collection

 class EEGDashDataset(BaseConcatDataset):
-    CACHE_DIR = '.eegdash_cache'
+    # CACHE_DIR = '.eegdash_cache'
     def __init__(
         self,
         query:dict=None,
         data_dir:str | list =None,
         dataset:str | list =None,
         description_fields: list[str]=['subject', 'session', 'run', 'task', 'age', 'gender', 'sex'],
+        cache_dir:str='.eegdash_cache',
         **kwargs
     ):
+        self.cache_dir = cache_dir
         if query:
             datasets = self.find_datasets(query, description_fields, **kwargs)
         elif data_dir:

@@ -247,6 +303,7 @@ class EEGDashDataset(BaseConcatDataset):
             datasets.extend(self.load_bids_dataset(dataset[i], data_dir[i], description_fields))
         # convert to list using get_item on each element
         super().__init__(datasets)
+

     def find_key_in_nested_dict(self, data, target_key):
         if isinstance(data, dict):

@@ -267,7 +324,7 @@ class EEGDashDataset(BaseConcatDataset):
                 value = self.find_key_in_nested_dict(record, field)
                 if value:
                     description[field] = value
-            datasets.append(EEGDashBaseDataset(record, self.
+            datasets.append(EEGDashBaseDataset(record, self.cache_dir, description=description, **kwargs))
         return datasets

     def load_bids_dataset(self, dataset, data_dir, description_fields: list[str],raw_format='eeglab', **kwargs):

@@ -280,9 +337,9 @@ class EEGDashDataset(BaseConcatDataset):
                 value = self.find_key_in_nested_dict(record, field)
                 if value:
                     description[field] = value
-            return EEGDashBaseDataset(record, self.
+            return EEGDashBaseDataset(record, self.cache_dir, description=description, **kwargs)

-        bids_dataset =
+        bids_dataset = EEGBIDSDataset(
             data_dir=data_dir,
             dataset=dataset,
             raw_format=raw_format,

@@ -291,15 +348,6 @@ class EEGDashDataset(BaseConcatDataset):
         datasets = Parallel(n_jobs=-1, prefer="threads", verbose=1)(
             delayed(get_base_dataset_from_bids_file)(bids_dataset, bids_file) for bids_file in bids_dataset.get_files()
         )
-        # datasets = []
-        # for bids_file in bids_dataset.get_files():
-        #     record = eegdashObj.load_eeg_attrs_from_bids_file(bids_dataset, bids_file)
-        #     description = {}
-        #     for field in description_fields:
-        #         value = self.find_key_in_nested_dict(record, field)
-        #         if value:
-        #             description[field] = value
-        #     datasets.append(EEGDashBaseDataset(record, self.CACHE_DIR, description=description, **kwargs))
         return datasets

 def main():