eegdash 0.3.3.dev61__py3-none-any.whl → 0.5.0.dev180784713__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- eegdash/__init__.py +19 -6
- eegdash/api.py +336 -539
- eegdash/bids_eeg_metadata.py +495 -0
- eegdash/const.py +349 -0
- eegdash/dataset/__init__.py +28 -0
- eegdash/dataset/base.py +311 -0
- eegdash/dataset/bids_dataset.py +641 -0
- eegdash/dataset/dataset.py +692 -0
- eegdash/dataset/dataset_summary.csv +255 -0
- eegdash/dataset/registry.py +287 -0
- eegdash/downloader.py +197 -0
- eegdash/features/__init__.py +15 -13
- eegdash/features/datasets.py +329 -138
- eegdash/features/decorators.py +105 -13
- eegdash/features/extractors.py +233 -63
- eegdash/features/feature_bank/__init__.py +12 -12
- eegdash/features/feature_bank/complexity.py +22 -20
- eegdash/features/feature_bank/connectivity.py +27 -28
- eegdash/features/feature_bank/csp.py +3 -1
- eegdash/features/feature_bank/dimensionality.py +6 -6
- eegdash/features/feature_bank/signal.py +29 -30
- eegdash/features/feature_bank/spectral.py +40 -44
- eegdash/features/feature_bank/utils.py +8 -0
- eegdash/features/inspect.py +126 -15
- eegdash/features/serialization.py +58 -17
- eegdash/features/utils.py +90 -16
- eegdash/hbn/__init__.py +28 -0
- eegdash/hbn/preprocessing.py +105 -0
- eegdash/hbn/windows.py +428 -0
- eegdash/logging.py +54 -0
- eegdash/mongodb.py +55 -24
- eegdash/paths.py +52 -0
- eegdash/utils.py +29 -1
- eegdash-0.5.0.dev180784713.dist-info/METADATA +121 -0
- eegdash-0.5.0.dev180784713.dist-info/RECORD +38 -0
- eegdash-0.5.0.dev180784713.dist-info/licenses/LICENSE +29 -0
- eegdash/data_config.py +0 -34
- eegdash/data_utils.py +0 -687
- eegdash/dataset.py +0 -69
- eegdash/preprocessing.py +0 -63
- eegdash-0.3.3.dev61.dist-info/METADATA +0 -192
- eegdash-0.3.3.dev61.dist-info/RECORD +0 -28
- eegdash-0.3.3.dev61.dist-info/licenses/LICENSE +0 -23
- {eegdash-0.3.3.dev61.dist-info → eegdash-0.5.0.dev180784713.dist-info}/WHEEL +0 -0
- {eegdash-0.3.3.dev61.dist-info → eegdash-0.5.0.dev180784713.dist-info}/top_level.txt +0 -0
eegdash/features/decorators.py
CHANGED
|
@@ -1,36 +1,99 @@
|
|
|
1
1
|
from collections.abc import Callable
|
|
2
|
-
from typing import List
|
|
2
|
+
from typing import List
|
|
3
3
|
|
|
4
4
|
from .extractors import (
|
|
5
5
|
BivariateFeature,
|
|
6
6
|
DirectedBivariateFeature,
|
|
7
|
-
FeatureExtractor,
|
|
8
7
|
MultivariateFeature,
|
|
9
8
|
UnivariateFeature,
|
|
10
9
|
_get_underlying_func,
|
|
11
10
|
)
|
|
12
11
|
|
|
12
|
+
__all__ = [
|
|
13
|
+
"bivariate_feature",
|
|
14
|
+
"FeatureKind",
|
|
15
|
+
"FeaturePredecessor",
|
|
16
|
+
"multivariate_feature",
|
|
17
|
+
"univariate_feature",
|
|
18
|
+
]
|
|
19
|
+
|
|
13
20
|
|
|
14
21
|
class FeaturePredecessor:
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
22
|
+
"""A decorator to specify parent extractors for a feature function.
|
|
23
|
+
|
|
24
|
+
This decorator attaches a list of immediate parent preprocessing steps to a feature
|
|
25
|
+
extraction function. This information can be used to build a dependency graph of
|
|
26
|
+
features.
|
|
27
|
+
|
|
28
|
+
Parameters
|
|
29
|
+
----------
|
|
30
|
+
*parent_extractor_type : list of Type
|
|
31
|
+
A list of preprocessing functions (subclasses of
|
|
32
|
+
:class:`~collections.abc.Callable` or None) that this feature immediately depends
|
|
33
|
+
on.
|
|
34
|
+
|
|
35
|
+
"""
|
|
36
|
+
|
|
37
|
+
def __init__(self, *parent_extractor_type: List[Callable | None]):
|
|
38
|
+
parent_func = parent_extractor_type
|
|
39
|
+
if not parent_func:
|
|
40
|
+
parent_func = [None]
|
|
41
|
+
for p_func in parent_func:
|
|
42
|
+
assert p_func is None or callable(p_func)
|
|
43
|
+
self.parent_extractor_type = parent_func
|
|
44
|
+
|
|
45
|
+
def __call__(self, func: Callable) -> Callable:
|
|
46
|
+
"""Apply the decorator to a function.
|
|
47
|
+
|
|
48
|
+
Parameters
|
|
49
|
+
----------
|
|
50
|
+
func : callable
|
|
51
|
+
The feature extraction function to decorate.
|
|
52
|
+
|
|
53
|
+
Returns
|
|
54
|
+
-------
|
|
55
|
+
callable
|
|
56
|
+
The decorated function with the `parent_extractor_type` attribute
|
|
57
|
+
set.
|
|
58
|
+
|
|
59
|
+
"""
|
|
24
60
|
f = _get_underlying_func(func)
|
|
25
61
|
f.parent_extractor_type = self.parent_extractor_type
|
|
26
62
|
return func
|
|
27
63
|
|
|
28
64
|
|
|
29
65
|
class FeatureKind:
|
|
66
|
+
"""A decorator to specify the kind of a feature.
|
|
67
|
+
|
|
68
|
+
This decorator attaches a "feature kind" (e.g., univariate, bivariate)
|
|
69
|
+
to a feature extraction function.
|
|
70
|
+
|
|
71
|
+
Parameters
|
|
72
|
+
----------
|
|
73
|
+
feature_kind : ~eegdash.features.extractors.MultivariateFeature
|
|
74
|
+
An instance of a feature kind class, such as
|
|
75
|
+
:class:`~eegdash.features.extractors.UnivariateFeature` or
|
|
76
|
+
:class:`~eegdash.features.extractors.BivariateFeature`.
|
|
77
|
+
|
|
78
|
+
"""
|
|
79
|
+
|
|
30
80
|
def __init__(self, feature_kind: MultivariateFeature):
|
|
31
81
|
self.feature_kind = feature_kind
|
|
32
82
|
|
|
33
|
-
def __call__(self, func):
|
|
83
|
+
def __call__(self, func: Callable) -> Callable:
|
|
84
|
+
"""Apply the decorator to a function.
|
|
85
|
+
|
|
86
|
+
Parameters
|
|
87
|
+
----------
|
|
88
|
+
func : callable
|
|
89
|
+
The feature extraction function to decorate.
|
|
90
|
+
|
|
91
|
+
Returns
|
|
92
|
+
-------
|
|
93
|
+
callable
|
|
94
|
+
The decorated function with the `feature_kind` attribute set.
|
|
95
|
+
|
|
96
|
+
"""
|
|
34
97
|
f = _get_underlying_func(func)
|
|
35
98
|
f.feature_kind = self.feature_kind
|
|
36
99
|
return func
|
|
@@ -38,9 +101,33 @@ class FeatureKind:
|
|
|
38
101
|
|
|
39
102
|
# Syntax sugar
|
|
40
103
|
univariate_feature = FeatureKind(UnivariateFeature())
|
|
104
|
+
"""Decorator to mark a feature as univariate.
|
|
105
|
+
|
|
106
|
+
This is a convenience instance of :class:`~eegdash.features.decorators.FeatureKind` pre-configured for
|
|
107
|
+
univariate features.
|
|
108
|
+
"""
|
|
41
109
|
|
|
42
110
|
|
|
43
|
-
def bivariate_feature(func, directed=False):
|
|
111
|
+
def bivariate_feature(func: Callable, directed: bool = False) -> Callable:
|
|
112
|
+
"""Decorator to mark a feature as bivariate.
|
|
113
|
+
|
|
114
|
+
This decorator specifies that the feature operates on pairs of channels.
|
|
115
|
+
|
|
116
|
+
Parameters
|
|
117
|
+
----------
|
|
118
|
+
func : callable
|
|
119
|
+
The feature extraction function to decorate.
|
|
120
|
+
directed : bool, default False
|
|
121
|
+
If True, the feature is directed (e.g., connectivity from channel A
|
|
122
|
+
to B is different from B to A). If False, the feature is undirected.
|
|
123
|
+
|
|
124
|
+
Returns
|
|
125
|
+
-------
|
|
126
|
+
callable
|
|
127
|
+
The decorated function with the appropriate bivariate feature kind
|
|
128
|
+
attached.
|
|
129
|
+
|
|
130
|
+
"""
|
|
44
131
|
if directed:
|
|
45
132
|
kind = DirectedBivariateFeature()
|
|
46
133
|
else:
|
|
@@ -49,3 +136,8 @@ def bivariate_feature(func, directed=False):
|
|
|
49
136
|
|
|
50
137
|
|
|
51
138
|
multivariate_feature = FeatureKind(MultivariateFeature())
|
|
139
|
+
"""Decorator to mark a feature as multivariate.
|
|
140
|
+
|
|
141
|
+
This is a convenience instance of :class:`~eegdash.features.decorators.FeatureKind` pre-configured for
|
|
142
|
+
multivariate features, which operate on all channels simultaneously.
|
|
143
|
+
"""
|
eegdash/features/extractors.py
CHANGED
|
@@ -1,3 +1,5 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
1
3
|
from abc import ABC, abstractmethod
|
|
2
4
|
from collections.abc import Callable
|
|
3
5
|
from functools import partial
|
|
@@ -6,8 +8,33 @@ from typing import Dict
|
|
|
6
8
|
import numpy as np
|
|
7
9
|
from numba.core.dispatcher import Dispatcher
|
|
8
10
|
|
|
11
|
+
__all__ = [
|
|
12
|
+
"BivariateFeature",
|
|
13
|
+
"DirectedBivariateFeature",
|
|
14
|
+
"FeatureExtractor",
|
|
15
|
+
"MultivariateFeature",
|
|
16
|
+
"TrainableFeature",
|
|
17
|
+
"UnivariateFeature",
|
|
18
|
+
]
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def _get_underlying_func(func: Callable) -> Callable:
|
|
22
|
+
"""Get the underlying function from a potential wrapper.
|
|
23
|
+
|
|
24
|
+
This helper unwraps functions that might be wrapped by `functools.partial`
|
|
25
|
+
or `numba.dispatcher.Dispatcher`.
|
|
26
|
+
|
|
27
|
+
Parameters
|
|
28
|
+
----------
|
|
29
|
+
func : callable
|
|
30
|
+
The function to unwrap.
|
|
9
31
|
|
|
10
|
-
|
|
32
|
+
Returns
|
|
33
|
+
-------
|
|
34
|
+
callable
|
|
35
|
+
The underlying Python function.
|
|
36
|
+
|
|
37
|
+
"""
|
|
11
38
|
f = func
|
|
12
39
|
if isinstance(f, partial):
|
|
13
40
|
f = f.func
|
|
@@ -17,22 +44,46 @@ def _get_underlying_func(func):
|
|
|
17
44
|
|
|
18
45
|
|
|
19
46
|
class TrainableFeature(ABC):
|
|
47
|
+
"""Abstract base class for features that require training.
|
|
48
|
+
|
|
49
|
+
This ABC defines the interface for feature extractors that need to be
|
|
50
|
+
fitted on data before they can be used. It includes methods for fitting
|
|
51
|
+
the feature extractor and for resetting its state.
|
|
52
|
+
"""
|
|
53
|
+
|
|
20
54
|
def __init__(self):
|
|
21
55
|
self._is_trained = False
|
|
22
56
|
self.clear()
|
|
23
57
|
|
|
24
58
|
@abstractmethod
|
|
25
59
|
def clear(self):
|
|
60
|
+
"""Reset the internal state of the feature extractor."""
|
|
26
61
|
pass
|
|
27
62
|
|
|
28
63
|
@abstractmethod
|
|
29
64
|
def partial_fit(self, *x, y=None):
|
|
65
|
+
"""Update the feature extractor's state with a batch of data.
|
|
66
|
+
|
|
67
|
+
Parameters
|
|
68
|
+
----------
|
|
69
|
+
*x : tuple
|
|
70
|
+
The input data for fitting.
|
|
71
|
+
y : any, optional
|
|
72
|
+
The target data, if required for supervised training.
|
|
73
|
+
|
|
74
|
+
"""
|
|
30
75
|
pass
|
|
31
76
|
|
|
32
77
|
def fit(self):
|
|
78
|
+
"""Finalize the training of the feature extractor.
|
|
79
|
+
|
|
80
|
+
This method should be called after all data has been seen via
|
|
81
|
+
`partial_fit`. It marks the feature as fitted.
|
|
82
|
+
"""
|
|
33
83
|
self._is_fitted = True
|
|
34
84
|
|
|
35
85
|
def __call__(self, *args, **kwargs):
|
|
86
|
+
"""Check if the feature is fitted before execution."""
|
|
36
87
|
if not self._is_fitted:
|
|
37
88
|
raise RuntimeError(
|
|
38
89
|
f"{self.__class__} cannot be called, it has to be fitted first."
|
|
@@ -40,59 +91,119 @@ class TrainableFeature(ABC):
|
|
|
40
91
|
|
|
41
92
|
|
|
42
93
|
class FeatureExtractor(TrainableFeature):
|
|
94
|
+
"""A composite feature extractor that applies multiple feature functions.
|
|
95
|
+
|
|
96
|
+
This class orchestrates the application of a dictionary of feature
|
|
97
|
+
extraction functions to input data. It can handle nested extractors,
|
|
98
|
+
pre-processing, and trainable features.
|
|
99
|
+
|
|
100
|
+
Parameters
|
|
101
|
+
----------
|
|
102
|
+
feature_extractors : dict[str, callable]
|
|
103
|
+
A dictionary where keys are feature names and values are the feature
|
|
104
|
+
extraction functions or other `FeatureExtractor` instances.
|
|
105
|
+
preprocessor
|
|
106
|
+
A shared preprocessing function for all child feature extraction functions.
|
|
107
|
+
|
|
108
|
+
"""
|
|
109
|
+
|
|
43
110
|
def __init__(
|
|
44
|
-
self,
|
|
111
|
+
self,
|
|
112
|
+
feature_extractors: Dict[str, Callable],
|
|
113
|
+
preprocessor: Callable | None = None,
|
|
45
114
|
):
|
|
115
|
+
self.preprocessor = preprocessor
|
|
46
116
|
self.feature_extractors_dict = self._validate_execution_tree(feature_extractors)
|
|
47
117
|
self._is_trainable = self._check_is_trainable(feature_extractors)
|
|
48
118
|
super().__init__()
|
|
49
119
|
|
|
50
120
|
# bypassing FeaturePredecessor to avoid circular import
|
|
51
121
|
if not hasattr(self, "parent_extractor_type"):
|
|
52
|
-
self.parent_extractor_type = [
|
|
53
|
-
|
|
54
|
-
self.
|
|
55
|
-
if
|
|
56
|
-
self.preprocess_kwargs =
|
|
57
|
-
self.features_kwargs = {
|
|
58
|
-
"preprocess_kwargs": preprocess_kwargs,
|
|
59
|
-
}
|
|
122
|
+
self.parent_extractor_type = [None]
|
|
123
|
+
|
|
124
|
+
self.features_kwargs = dict()
|
|
125
|
+
if preprocessor is not None and isinstance(preprocessor, partial):
|
|
126
|
+
self.features_kwargs["preprocess_kwargs"] = preprocessor.args
|
|
60
127
|
for fn, fe in feature_extractors.items():
|
|
61
128
|
if isinstance(fe, FeatureExtractor):
|
|
62
129
|
self.features_kwargs[fn] = fe.features_kwargs
|
|
63
130
|
if isinstance(fe, partial):
|
|
64
131
|
self.features_kwargs[fn] = fe.keywords
|
|
65
132
|
|
|
66
|
-
def _validate_execution_tree(self, feature_extractors):
|
|
133
|
+
def _validate_execution_tree(self, feature_extractors: dict) -> dict:
|
|
134
|
+
"""Validate the feature dependency graph."""
|
|
135
|
+
preprocessor = (
|
|
136
|
+
None
|
|
137
|
+
if self.preprocessor is None
|
|
138
|
+
else _get_underlying_func(self.preprocessor)
|
|
139
|
+
)
|
|
67
140
|
for fname, f in feature_extractors.items():
|
|
141
|
+
if isinstance(f, FeatureExtractor):
|
|
142
|
+
f = f.preprocessor
|
|
68
143
|
f = _get_underlying_func(f)
|
|
69
|
-
pe_type = getattr(f, "parent_extractor_type", [
|
|
70
|
-
|
|
144
|
+
pe_type = getattr(f, "parent_extractor_type", [None])
|
|
145
|
+
if preprocessor not in pe_type:
|
|
146
|
+
parent = getattr(preprocessor, "__name__", preprocessor)
|
|
147
|
+
child = getattr(f, "__name__", f)
|
|
148
|
+
raise TypeError(
|
|
149
|
+
f"Feature '{fname}: {child}' cannot be a child of {parent}"
|
|
150
|
+
)
|
|
71
151
|
return feature_extractors
|
|
72
152
|
|
|
73
|
-
def _check_is_trainable(self, feature_extractors):
|
|
74
|
-
|
|
153
|
+
def _check_is_trainable(self, feature_extractors: dict) -> bool:
|
|
154
|
+
"""Check if any of the contained features are trainable."""
|
|
75
155
|
for fname, f in feature_extractors.items():
|
|
76
156
|
if isinstance(f, FeatureExtractor):
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
157
|
+
if f._is_trainable:
|
|
158
|
+
return True
|
|
159
|
+
elif isinstance(_get_underlying_func(f), TrainableFeature):
|
|
160
|
+
return True
|
|
161
|
+
return False
|
|
162
|
+
|
|
163
|
+
def preprocess(self, *x):
|
|
164
|
+
"""Apply pre-processing to the input data.
|
|
165
|
+
|
|
166
|
+
Parameters
|
|
167
|
+
----------
|
|
168
|
+
*x : tuple
|
|
169
|
+
Input data.
|
|
170
|
+
|
|
171
|
+
Returns
|
|
172
|
+
-------
|
|
173
|
+
tuple
|
|
174
|
+
The pre-processed data.
|
|
175
|
+
|
|
176
|
+
"""
|
|
177
|
+
if self.preprocessor is None:
|
|
178
|
+
return (*x,)
|
|
179
|
+
else:
|
|
180
|
+
return self.preprocessor(*x)
|
|
181
|
+
|
|
182
|
+
def __call__(self, *x, _batch_size=None, _ch_names=None) -> dict:
|
|
183
|
+
"""Apply all feature extractors to the input data.
|
|
184
|
+
|
|
185
|
+
Parameters
|
|
186
|
+
----------
|
|
187
|
+
*x : tuple
|
|
188
|
+
Input data.
|
|
189
|
+
_batch_size : int, optional
|
|
190
|
+
The number of samples in the batch.
|
|
191
|
+
_ch_names : list of str, optional
|
|
192
|
+
The names of the channels in the input data.
|
|
193
|
+
|
|
194
|
+
Returns
|
|
195
|
+
-------
|
|
196
|
+
dict
|
|
197
|
+
A dictionary where keys are feature names and values are the
|
|
198
|
+
computed feature values.
|
|
199
|
+
|
|
200
|
+
"""
|
|
90
201
|
assert _batch_size is not None
|
|
91
202
|
assert _ch_names is not None
|
|
92
203
|
if self._is_trainable:
|
|
93
204
|
super().__call__()
|
|
94
205
|
results_dict = dict()
|
|
95
|
-
z = self.preprocess(*x
|
|
206
|
+
z = self.preprocess(*x)
|
|
96
207
|
if not isinstance(z, tuple):
|
|
97
208
|
z = (z,)
|
|
98
209
|
for fname, f in self.feature_extractors_dict.items():
|
|
@@ -100,51 +211,53 @@ class FeatureExtractor(TrainableFeature):
|
|
|
100
211
|
r = f(*z, _batch_size=_batch_size, _ch_names=_ch_names)
|
|
101
212
|
else:
|
|
102
213
|
r = f(*z)
|
|
103
|
-
|
|
104
|
-
if hasattr(
|
|
105
|
-
r =
|
|
214
|
+
f_und = _get_underlying_func(f)
|
|
215
|
+
if hasattr(f_und, "feature_kind"):
|
|
216
|
+
r = f_und.feature_kind(r, _ch_names=_ch_names)
|
|
106
217
|
if not isinstance(fname, str) or not fname:
|
|
107
|
-
|
|
108
|
-
fname = ""
|
|
109
|
-
else:
|
|
110
|
-
fname = f.__name__
|
|
218
|
+
fname = getattr(f_und, "__name__", "")
|
|
111
219
|
if isinstance(r, dict):
|
|
112
|
-
if fname
|
|
113
|
-
fname += "_"
|
|
220
|
+
prefix = f"{fname}_" if fname else ""
|
|
114
221
|
for k, v in r.items():
|
|
115
|
-
self._add_feature_to_dict(results_dict,
|
|
222
|
+
self._add_feature_to_dict(results_dict, prefix + k, v, _batch_size)
|
|
116
223
|
else:
|
|
117
224
|
self._add_feature_to_dict(results_dict, fname, r, _batch_size)
|
|
118
225
|
return results_dict
|
|
119
226
|
|
|
120
|
-
def _add_feature_to_dict(
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
227
|
+
def _add_feature_to_dict(
|
|
228
|
+
self, results_dict: dict, name: str, value: any, batch_size: int
|
|
229
|
+
):
|
|
230
|
+
"""Add a computed feature to the results dictionary."""
|
|
231
|
+
if isinstance(value, np.ndarray):
|
|
124
232
|
assert value.shape[0] == batch_size
|
|
125
|
-
|
|
233
|
+
results_dict[name] = value
|
|
126
234
|
|
|
127
235
|
def clear(self):
|
|
236
|
+
"""Clear the state of all trainable sub-features."""
|
|
128
237
|
if not self._is_trainable:
|
|
129
238
|
return
|
|
130
|
-
for
|
|
239
|
+
for f in self.feature_extractors_dict.values():
|
|
131
240
|
f = _get_underlying_func(f)
|
|
132
241
|
if isinstance(f, TrainableFeature):
|
|
133
242
|
f.clear()
|
|
134
243
|
|
|
135
244
|
def partial_fit(self, *x, y=None):
|
|
245
|
+
"""Partially fit all trainable sub-features."""
|
|
136
246
|
if not self._is_trainable:
|
|
137
247
|
return
|
|
138
|
-
z = self.preprocess(*x
|
|
139
|
-
|
|
248
|
+
z = self.preprocess(*x)
|
|
249
|
+
if not isinstance(z, tuple):
|
|
250
|
+
z = (z,)
|
|
251
|
+
for f in self.feature_extractors_dict.values():
|
|
140
252
|
f = _get_underlying_func(f)
|
|
141
253
|
if isinstance(f, TrainableFeature):
|
|
142
254
|
f.partial_fit(*z, y=y)
|
|
143
255
|
|
|
144
256
|
def fit(self):
|
|
257
|
+
"""Fit all trainable sub-features."""
|
|
145
258
|
if not self._is_trainable:
|
|
146
259
|
return
|
|
147
|
-
for
|
|
260
|
+
for f in self.feature_extractors_dict.values():
|
|
148
261
|
f = _get_underlying_func(f)
|
|
149
262
|
if isinstance(f, TrainableFeature):
|
|
150
263
|
f.fit()
|
|
@@ -152,7 +265,32 @@ class FeatureExtractor(TrainableFeature):
|
|
|
152
265
|
|
|
153
266
|
|
|
154
267
|
class MultivariateFeature:
|
|
155
|
-
|
|
268
|
+
"""A mixin for features that operate on multiple channels.
|
|
269
|
+
|
|
270
|
+
This class provides a `__call__` method that converts a feature array into
|
|
271
|
+
a dictionary with named features, where names are derived from channel
|
|
272
|
+
names.
|
|
273
|
+
"""
|
|
274
|
+
|
|
275
|
+
def __call__(
|
|
276
|
+
self, x: np.ndarray, _ch_names: list[str] | None = None
|
|
277
|
+
) -> dict | np.ndarray:
|
|
278
|
+
"""Convert a feature array to a named dictionary.
|
|
279
|
+
|
|
280
|
+
Parameters
|
|
281
|
+
----------
|
|
282
|
+
x : numpy.ndarray
|
|
283
|
+
The computed feature array.
|
|
284
|
+
_ch_names : list of str, optional
|
|
285
|
+
The list of channel names.
|
|
286
|
+
|
|
287
|
+
Returns
|
|
288
|
+
-------
|
|
289
|
+
dict or numpy.ndarray
|
|
290
|
+
A dictionary of named features, or the original array if feature
|
|
291
|
+
channel names cannot be generated.
|
|
292
|
+
|
|
293
|
+
"""
|
|
156
294
|
assert _ch_names is not None
|
|
157
295
|
f_channels = self.feature_channel_names(_ch_names)
|
|
158
296
|
if isinstance(x, dict):
|
|
@@ -163,37 +301,66 @@ class MultivariateFeature:
|
|
|
163
301
|
return self._array_to_dict(x, f_channels)
|
|
164
302
|
|
|
165
303
|
@staticmethod
|
|
166
|
-
def _array_to_dict(
|
|
304
|
+
def _array_to_dict(
|
|
305
|
+
x: np.ndarray, f_channels: list[str], name: str = ""
|
|
306
|
+
) -> dict | np.ndarray:
|
|
307
|
+
"""Convert a numpy array to a dictionary with named keys."""
|
|
167
308
|
assert isinstance(x, np.ndarray)
|
|
168
|
-
if
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
return {name: x}
|
|
172
|
-
return x
|
|
173
|
-
assert x.shape[1] == len(f_channels)
|
|
309
|
+
if not f_channels:
|
|
310
|
+
return {name: x} if name else x
|
|
311
|
+
assert x.shape[1] == len(f_channels), f"{x.shape[1]} != {len(f_channels)}"
|
|
174
312
|
x = x.swapaxes(0, 1)
|
|
175
|
-
|
|
313
|
+
prefix = f"{name}_" if name else ""
|
|
314
|
+
names = [f"{prefix}{ch}" for ch in f_channels]
|
|
176
315
|
return dict(zip(names, x))
|
|
177
316
|
|
|
178
|
-
def feature_channel_names(self, ch_names):
|
|
317
|
+
def feature_channel_names(self, ch_names: list[str]) -> list[str]:
|
|
318
|
+
"""Generate feature names based on channel names.
|
|
319
|
+
|
|
320
|
+
Parameters
|
|
321
|
+
----------
|
|
322
|
+
ch_names : list of str
|
|
323
|
+
The names of the input channels.
|
|
324
|
+
|
|
325
|
+
Returns
|
|
326
|
+
-------
|
|
327
|
+
list of str
|
|
328
|
+
The names for the output features.
|
|
329
|
+
|
|
330
|
+
"""
|
|
179
331
|
return []
|
|
180
332
|
|
|
181
333
|
|
|
182
334
|
class UnivariateFeature(MultivariateFeature):
|
|
183
|
-
|
|
335
|
+
"""A feature kind for operations applied to each channel independently."""
|
|
336
|
+
|
|
337
|
+
def feature_channel_names(self, ch_names: list[str]) -> list[str]:
|
|
338
|
+
"""Return the channel names themselves as feature names."""
|
|
184
339
|
return ch_names
|
|
185
340
|
|
|
186
341
|
|
|
187
342
|
class BivariateFeature(MultivariateFeature):
|
|
188
|
-
|
|
343
|
+
"""A feature kind for operations on pairs of channels.
|
|
344
|
+
|
|
345
|
+
Parameters
|
|
346
|
+
----------
|
|
347
|
+
channel_pair_format : str, default="{}<>{}"
|
|
348
|
+
A format string used to create feature names from pairs of
|
|
349
|
+
channel names.
|
|
350
|
+
|
|
351
|
+
"""
|
|
352
|
+
|
|
353
|
+
def __init__(self, *args, channel_pair_format: str = "{}<>{}"):
|
|
189
354
|
super().__init__(*args)
|
|
190
355
|
self.channel_pair_format = channel_pair_format
|
|
191
356
|
|
|
192
357
|
@staticmethod
|
|
193
|
-
def get_pair_iterators(n):
|
|
358
|
+
def get_pair_iterators(n: int) -> tuple[np.ndarray, np.ndarray]:
|
|
359
|
+
"""Get indices for unique, unordered pairs of channels."""
|
|
194
360
|
return np.triu_indices(n, 1)
|
|
195
361
|
|
|
196
|
-
def feature_channel_names(self, ch_names):
|
|
362
|
+
def feature_channel_names(self, ch_names: list[str]) -> list[str]:
|
|
363
|
+
"""Generate feature names for each pair of channels."""
|
|
197
364
|
return [
|
|
198
365
|
self.channel_pair_format.format(ch_names[i], ch_names[j])
|
|
199
366
|
for i, j in zip(*self.get_pair_iterators(len(ch_names)))
|
|
@@ -201,8 +368,11 @@ class BivariateFeature(MultivariateFeature):
|
|
|
201
368
|
|
|
202
369
|
|
|
203
370
|
class DirectedBivariateFeature(BivariateFeature):
|
|
371
|
+
"""A feature kind for directed operations on pairs of channels."""
|
|
372
|
+
|
|
204
373
|
@staticmethod
|
|
205
|
-
def get_pair_iterators(n):
|
|
374
|
+
def get_pair_iterators(n: int) -> list[np.ndarray]:
|
|
375
|
+
"""Get indices for all ordered pairs of channels (excluding self-pairs)."""
|
|
206
376
|
return [
|
|
207
377
|
np.append(a, b)
|
|
208
378
|
for a, b in zip(np.tril_indices(n, -1), np.triu_indices(n, 1))
|
|
@@ -6,14 +6,14 @@ functions so users can import them directly from
|
|
|
6
6
|
"""
|
|
7
7
|
|
|
8
8
|
from .complexity import (
|
|
9
|
-
EntropyFeatureExtractor,
|
|
10
9
|
complexity_approx_entropy,
|
|
10
|
+
complexity_entropy_preprocessor,
|
|
11
11
|
complexity_lempel_ziv,
|
|
12
12
|
complexity_sample_entropy,
|
|
13
13
|
complexity_svd_entropy,
|
|
14
14
|
)
|
|
15
15
|
from .connectivity import (
|
|
16
|
-
|
|
16
|
+
connectivity_coherency_preprocessor,
|
|
17
17
|
connectivity_imaginary_coherence,
|
|
18
18
|
connectivity_lagged_coherence,
|
|
19
19
|
connectivity_magnitude_square_coherence,
|
|
@@ -27,8 +27,8 @@ from .dimensionality import (
|
|
|
27
27
|
dimensionality_petrosian_fractal_dim,
|
|
28
28
|
)
|
|
29
29
|
from .signal import (
|
|
30
|
-
HilbertFeatureExtractor,
|
|
31
30
|
signal_decorrelation_time,
|
|
31
|
+
signal_hilbert_preprocessor,
|
|
32
32
|
signal_hjorth_activity,
|
|
33
33
|
signal_hjorth_complexity,
|
|
34
34
|
signal_hjorth_mobility,
|
|
@@ -44,29 +44,29 @@ from .signal import (
|
|
|
44
44
|
signal_zero_crossings,
|
|
45
45
|
)
|
|
46
46
|
from .spectral import (
|
|
47
|
-
DBSpectralFeatureExtractor,
|
|
48
|
-
NormalizedSpectralFeatureExtractor,
|
|
49
|
-
SpectralFeatureExtractor,
|
|
50
47
|
spectral_bands_power,
|
|
48
|
+
spectral_db_preprocessor,
|
|
51
49
|
spectral_edge,
|
|
52
50
|
spectral_entropy,
|
|
53
51
|
spectral_hjorth_activity,
|
|
54
52
|
spectral_hjorth_complexity,
|
|
55
53
|
spectral_hjorth_mobility,
|
|
56
54
|
spectral_moment,
|
|
55
|
+
spectral_normalized_preprocessor,
|
|
56
|
+
spectral_preprocessor,
|
|
57
57
|
spectral_root_total_power,
|
|
58
58
|
spectral_slope,
|
|
59
59
|
)
|
|
60
60
|
|
|
61
61
|
__all__ = [
|
|
62
62
|
# Complexity
|
|
63
|
-
"
|
|
63
|
+
"complexity_entropy_preprocessor",
|
|
64
64
|
"complexity_approx_entropy",
|
|
65
65
|
"complexity_sample_entropy",
|
|
66
66
|
"complexity_svd_entropy",
|
|
67
67
|
"complexity_lempel_ziv",
|
|
68
68
|
# Connectivity
|
|
69
|
-
"
|
|
69
|
+
"connectivity_coherency_preprocessor",
|
|
70
70
|
"connectivity_magnitude_square_coherence",
|
|
71
71
|
"connectivity_imaginary_coherence",
|
|
72
72
|
"connectivity_lagged_coherence",
|
|
@@ -79,7 +79,7 @@ __all__ = [
|
|
|
79
79
|
"dimensionality_hurst_exp",
|
|
80
80
|
"dimensionality_detrended_fluctuation_analysis",
|
|
81
81
|
# Signal
|
|
82
|
-
"
|
|
82
|
+
"signal_hilbert_preprocessor",
|
|
83
83
|
"signal_mean",
|
|
84
84
|
"signal_variance",
|
|
85
85
|
"signal_skewness",
|
|
@@ -95,9 +95,9 @@ __all__ = [
|
|
|
95
95
|
"signal_hjorth_complexity",
|
|
96
96
|
"signal_decorrelation_time",
|
|
97
97
|
# Spectral
|
|
98
|
-
"
|
|
99
|
-
"
|
|
100
|
-
"
|
|
98
|
+
"spectral_preprocessor",
|
|
99
|
+
"spectral_normalized_preprocessor",
|
|
100
|
+
"spectral_db_preprocessor",
|
|
101
101
|
"spectral_root_total_power",
|
|
102
102
|
"spectral_moment",
|
|
103
103
|
"spectral_entropy",
|