eegdash 0.4.0.dev150__py3-none-any.whl → 0.4.0.dev162__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- eegdash/__init__.py +1 -1
- eegdash/api.py +180 -86
- eegdash/bids_eeg_metadata.py +139 -39
- eegdash/const.py +25 -0
- eegdash/data_utils.py +239 -173
- eegdash/dataset/dataset.py +35 -13
- eegdash/dataset/dataset_summary.csv +1 -1
- eegdash/dataset/registry.py +69 -4
- eegdash/downloader.py +95 -9
- eegdash/features/datasets.py +320 -136
- eegdash/features/decorators.py +88 -3
- eegdash/features/extractors.py +201 -55
- eegdash/features/inspect.py +78 -5
- eegdash/features/serialization.py +45 -19
- eegdash/features/utils.py +75 -8
- eegdash/hbn/preprocessing.py +50 -17
- eegdash/hbn/windows.py +145 -32
- eegdash/logging.py +19 -0
- eegdash/mongodb.py +44 -27
- eegdash/paths.py +14 -5
- eegdash/utils.py +16 -1
- {eegdash-0.4.0.dev150.dist-info → eegdash-0.4.0.dev162.dist-info}/METADATA +1 -1
- eegdash-0.4.0.dev162.dist-info/RECORD +37 -0
- eegdash-0.4.0.dev150.dist-info/RECORD +0 -37
- {eegdash-0.4.0.dev150.dist-info → eegdash-0.4.0.dev162.dist-info}/WHEEL +0 -0
- {eegdash-0.4.0.dev150.dist-info → eegdash-0.4.0.dev162.dist-info}/licenses/LICENSE +0 -0
- {eegdash-0.4.0.dev150.dist-info → eegdash-0.4.0.dev162.dist-info}/top_level.txt +0 -0
eegdash/features/decorators.py
CHANGED

@@ -12,6 +12,21 @@ from .extractors import (
 
 
 class FeaturePredecessor:
+    """A decorator to specify parent extractors for a feature function.
+
+    This decorator attaches a list of parent extractor types to a feature
+    extraction function. This information can be used to build a dependency
+    graph of features.
+
+    Parameters
+    ----------
+    *parent_extractor_type : list of Type
+        A list of feature extractor classes (subclasses of
+        :class:`~eegdash.features.extractors.FeatureExtractor`) that this
+        feature depends on.
+
+    """
+
     def __init__(self, *parent_extractor_type: List[Type]):
         parent_cls = parent_extractor_type
         if not parent_cls:
@@ -20,17 +35,58 @@ class FeaturePredecessor:
             assert issubclass(p_cls, FeatureExtractor)
         self.parent_extractor_type = parent_cls
 
-    def __call__(self, func: Callable):
+    def __call__(self, func: Callable) -> Callable:
+        """Apply the decorator to a function.
+
+        Parameters
+        ----------
+        func : callable
+            The feature extraction function to decorate.
+
+        Returns
+        -------
+        callable
+            The decorated function with the `parent_extractor_type` attribute
+            set.
+
+        """
         f = _get_underlying_func(func)
         f.parent_extractor_type = self.parent_extractor_type
         return func
 
 
 class FeatureKind:
+    """A decorator to specify the kind of a feature.
+
+    This decorator attaches a "feature kind" (e.g., univariate, bivariate)
+    to a feature extraction function.
+
+    Parameters
+    ----------
+    feature_kind : MultivariateFeature
+        An instance of a feature kind class, such as
+        :class:`~eegdash.features.extractors.UnivariateFeature` or
+        :class:`~eegdash.features.extractors.BivariateFeature`.
+
+    """
+
     def __init__(self, feature_kind: MultivariateFeature):
         self.feature_kind = feature_kind
 
-    def __call__(self, func):
+    def __call__(self, func: Callable) -> Callable:
+        """Apply the decorator to a function.
+
+        Parameters
+        ----------
+        func : callable
+            The feature extraction function to decorate.
+
+        Returns
+        -------
+        callable
+            The decorated function with the `feature_kind` attribute set.
+
+        """
         f = _get_underlying_func(func)
         f.feature_kind = self.feature_kind
         return func
@@ -38,9 +94,33 @@ class FeatureKind:
 
 # Syntax sugar
 univariate_feature = FeatureKind(UnivariateFeature())
+"""Decorator to mark a feature as univariate.
+
+This is a convenience instance of :class:`FeatureKind` pre-configured for
+univariate features.
+"""
 
 
-def bivariate_feature(func, directed=False):
+def bivariate_feature(func: Callable, directed: bool = False) -> Callable:
+    """Decorator to mark a feature as bivariate.
+
+    This decorator specifies that the feature operates on pairs of channels.
+
+    Parameters
+    ----------
+    func : callable
+        The feature extraction function to decorate.
+    directed : bool, default False
+        If True, the feature is directed (e.g., connectivity from channel A
+        to B is different from B to A). If False, the feature is undirected.
+
+    Returns
+    -------
+    callable
+        The decorated function with the appropriate bivariate feature kind
+        attached.
+
+    """
     if directed:
         kind = DirectedBivariateFeature()
     else:
@@ -49,3 +129,8 @@ def bivariate_feature(func, directed=False):
 
 
 multivariate_feature = FeatureKind(MultivariateFeature())
+"""Decorator to mark a feature as multivariate.
+
+This is a convenience instance of :class:`FeatureKind` pre-configured for
+multivariate features, which operate on all channels simultaneously.
+"""
eegdash/features/extractors.py
CHANGED

@@ -7,7 +7,23 @@ import numpy as np
 from numba.core.dispatcher import Dispatcher
 
 
-def _get_underlying_func(func):
+def _get_underlying_func(func: Callable) -> Callable:
+    """Get the underlying function from a potential wrapper.
+
+    This helper unwraps functions that might be wrapped by `functools.partial`
+    or `numba.dispatcher.Dispatcher`.
+
+    Parameters
+    ----------
+    func : callable
+        The function to unwrap.
+
+    Returns
+    -------
+    callable
+        The underlying Python function.
+
+    """
     f = func
     if isinstance(f, partial):
         f = f.func
@@ -17,22 +33,46 @@ def _get_underlying_func(func):
 
 
 class TrainableFeature(ABC):
+    """Abstract base class for features that require training.
+
+    This ABC defines the interface for feature extractors that need to be
+    fitted on data before they can be used. It includes methods for fitting
+    the feature extractor and for resetting its state.
+    """
+
     def __init__(self):
         self._is_trained = False
         self.clear()
 
     @abstractmethod
     def clear(self):
+        """Reset the internal state of the feature extractor."""
         pass
 
     @abstractmethod
     def partial_fit(self, *x, y=None):
+        """Update the feature extractor's state with a batch of data.
+
+        Parameters
+        ----------
+        *x : tuple
+            The input data for fitting.
+        y : any, optional
+            The target data, if required for supervised training.
+
+        """
         pass
 
     def fit(self):
+        """Finalize the training of the feature extractor.
+
+        This method should be called after all data has been seen via
+        `partial_fit`. It marks the feature as fitted.
+        """
        self._is_fitted = True
 
     def __call__(self, *args, **kwargs):
+        """Check if the feature is fitted before execution."""
         if not self._is_fitted:
             raise RuntimeError(
                 f"{self.__class__} cannot be called, it has to be fitted first."
@@ -40,6 +80,22 @@ class TrainableFeature(ABC):
 
 
 class FeatureExtractor(TrainableFeature):
+    """A composite feature extractor that applies multiple feature functions.
+
+    This class orchestrates the application of a dictionary of feature
+    extraction functions to input data. It can handle nested extractors,
+    pre-processing, and trainable features.
+
+    Parameters
+    ----------
+    feature_extractors : dict[str, callable]
+        A dictionary where keys are feature names and values are the feature
+        extraction functions or other `FeatureExtractor` instances.
+    **preprocess_kwargs
+        Keyword arguments to be passed to the `preprocess` method.
+
+    """
+
     def __init__(
         self, feature_extractors: Dict[str, Callable], **preprocess_kwargs: Dict
     ):
@@ -63,30 +119,64 @@ class FeatureExtractor(TrainableFeature):
             if isinstance(fe, partial):
                 self.features_kwargs[fn] = fe.keywords
 
-    def _validate_execution_tree(self, feature_extractors):
+    def _validate_execution_tree(self, feature_extractors: dict) -> dict:
+        """Validate the feature dependency graph."""
         for fname, f in feature_extractors.items():
             f = _get_underlying_func(f)
             pe_type = getattr(f, "parent_extractor_type", [FeatureExtractor])
-
+            if type(self) not in pe_type:
+                raise TypeError(
+                    f"Feature '{fname}' cannot be a child of {type(self).__name__}"
+                )
         return feature_extractors
 
-    def _check_is_trainable(self, feature_extractors):
-
+    def _check_is_trainable(self, feature_extractors: dict) -> bool:
+        """Check if any of the contained features are trainable."""
         for fname, f in feature_extractors.items():
             if isinstance(f, FeatureExtractor):
-
-
-
-
-
-            if is_trainable:
-                break
-        return is_trainable
+                if f._is_trainable:
+                    return True
+            elif isinstance(_get_underlying_func(f), TrainableFeature):
+                return True
+        return False
 
     def preprocess(self, *x, **kwargs):
+        """Apply pre-processing to the input data.
+
+        Parameters
+        ----------
+        *x : tuple
+            Input data.
+        **kwargs
+            Additional keyword arguments.
+
+        Returns
+        -------
+        tuple
+            The pre-processed data.
+
+        """
         return (*x,)
 
-    def __call__(self, *x, _batch_size=None, _ch_names=None):
+    def __call__(self, *x, _batch_size=None, _ch_names=None) -> dict:
+        """Apply all feature extractors to the input data.
+
+        Parameters
+        ----------
+        *x : tuple
+            Input data.
+        _batch_size : int, optional
+            The number of samples in the batch.
+        _ch_names : list of str, optional
+            The names of the channels in the input data.
+
+        Returns
+        -------
+        dict
+            A dictionary where keys are feature names and values are the
+            computed feature values.
+
+        """
         assert _batch_size is not None
         assert _ch_names is not None
         if self._is_trainable:
@@ -100,59 +190,83 @@ class FeatureExtractor(TrainableFeature):
                 r = f(*z, _batch_size=_batch_size, _ch_names=_ch_names)
             else:
                 r = f(*z)
-
-            if hasattr(
-                r =
+            f_und = _get_underlying_func(f)
+            if hasattr(f_und, "feature_kind"):
+                r = f_und.feature_kind(r, _ch_names=_ch_names)
             if not isinstance(fname, str) or not fname:
-
-                fname = ""
-            else:
-                fname = f.__name__
+                fname = getattr(f_und, "__name__", "")
             if isinstance(r, dict):
-                if fname
-                    fname += "_"
+                prefix = f"{fname}_" if fname else ""
                 for k, v in r.items():
-                    self._add_feature_to_dict(results_dict,
+                    self._add_feature_to_dict(results_dict, prefix + k, v, _batch_size)
             else:
                 self._add_feature_to_dict(results_dict, fname, r, _batch_size)
         return results_dict
 
-    def _add_feature_to_dict(
-
-
-
+    def _add_feature_to_dict(
+        self, results_dict: dict, name: str, value: any, batch_size: int
+    ):
+        """Add a computed feature to the results dictionary."""
+        if isinstance(value, np.ndarray):
             assert value.shape[0] == batch_size
-
+        results_dict[name] = value
 
     def clear(self):
+        """Clear the state of all trainable sub-features."""
         if not self._is_trainable:
             return
-        for
-
-
-            f.clear()
+        for f in self.feature_extractors_dict.values():
+            if isinstance(_get_underlying_func(f), TrainableFeature):
+                _get_underlying_func(f).clear()
 
     def partial_fit(self, *x, y=None):
+        """Partially fit all trainable sub-features."""
         if not self._is_trainable:
             return
         z = self.preprocess(*x, **self.preprocess_kwargs)
-
-
-
-
+        if not isinstance(z, tuple):
+            z = (z,)
+        for f in self.feature_extractors_dict.values():
+            if isinstance(_get_underlying_func(f), TrainableFeature):
+                _get_underlying_func(f).partial_fit(*z, y=y)
 
     def fit(self):
+        """Fit all trainable sub-features."""
         if not self._is_trainable:
             return
-        for
-
-            if isinstance(f, TrainableFeature):
+        for f in self.feature_extractors_dict.values():
+            if isinstance(_get_underlying_func(f), TrainableFeature):
                 f.fit()
         super().fit()
 
 
 class MultivariateFeature:
-
+    """A mixin for features that operate on multiple channels.
+
+    This class provides a `__call__` method that converts a feature array into
+    a dictionary with named features, where names are derived from channel
+    names.
+    """
+
+    def __call__(
+        self, x: np.ndarray, _ch_names: list[str] | None = None
+    ) -> dict | np.ndarray:
+        """Convert a feature array to a named dictionary.
+
+        Parameters
+        ----------
+        x : numpy.ndarray
+            The computed feature array.
+        _ch_names : list of str, optional
+            The list of channel names.
+
+        Returns
+        -------
+        dict or numpy.ndarray
+            A dictionary of named features, or the original array if feature
+            channel names cannot be generated.
+
+        """
         assert _ch_names is not None
         f_channels = self.feature_channel_names(_ch_names)
         if isinstance(x, dict):
@@ -163,37 +277,66 @@ class MultivariateFeature:
         return self._array_to_dict(x, f_channels)
 
     @staticmethod
-    def _array_to_dict(
+    def _array_to_dict(
+        x: np.ndarray, f_channels: list[str], name: str = ""
+    ) -> dict | np.ndarray:
+        """Convert a numpy array to a dictionary with named keys."""
         assert isinstance(x, np.ndarray)
-        if
-
-
-            return {name: x}
-        return x
-        assert x.shape[1] == len(f_channels)
+        if not f_channels:
+            return {name: x} if name else x
+        assert x.shape[1] == len(f_channels), f"{x.shape[1]} != {len(f_channels)}"
         x = x.swapaxes(0, 1)
-
+        prefix = f"{name}_" if name else ""
+        names = [f"{prefix}{ch}" for ch in f_channels]
         return dict(zip(names, x))
 
-    def feature_channel_names(self, ch_names):
+    def feature_channel_names(self, ch_names: list[str]) -> list[str]:
+        """Generate feature names based on channel names.
+
+        Parameters
+        ----------
+        ch_names : list of str
+            The names of the input channels.
+
+        Returns
+        -------
+        list of str
+            The names for the output features.
+
+        """
         return []
 
 
 class UnivariateFeature(MultivariateFeature):
-
+    """A feature kind for operations applied to each channel independently."""
+
+    def feature_channel_names(self, ch_names: list[str]) -> list[str]:
+        """Return the channel names themselves as feature names."""
         return ch_names
 
 
 class BivariateFeature(MultivariateFeature):
-
+    """A feature kind for operations on pairs of channels.
+
+    Parameters
+    ----------
+    channel_pair_format : str, default="{}<>{}"
+        A format string used to create feature names from pairs of
+        channel names.
+
+    """
+
+    def __init__(self, *args, channel_pair_format: str = "{}<>{}"):
         super().__init__(*args)
         self.channel_pair_format = channel_pair_format
 
     @staticmethod
-    def get_pair_iterators(n):
+    def get_pair_iterators(n: int) -> tuple[np.ndarray, np.ndarray]:
+        """Get indices for unique, unordered pairs of channels."""
         return np.triu_indices(n, 1)
 
-    def feature_channel_names(self, ch_names):
+    def feature_channel_names(self, ch_names: list[str]) -> list[str]:
+        """Generate feature names for each pair of channels."""
         return [
             self.channel_pair_format.format(ch_names[i], ch_names[j])
             for i, j in zip(*self.get_pair_iterators(len(ch_names)))
@@ -201,8 +344,11 @@ class BivariateFeature(MultivariateFeature):
 
 
 class DirectedBivariateFeature(BivariateFeature):
+    """A feature kind for directed operations on pairs of channels."""
+
     @staticmethod
-    def get_pair_iterators(n):
+    def get_pair_iterators(n: int) -> list[np.ndarray]:
+        """Get indices for all ordered pairs of channels (excluding self-pairs)."""
         return [
             np.append(a, b)
             for a, b in zip(np.tril_indices(n, -1), np.triu_indices(n, 1))
eegdash/features/inspect.py
CHANGED

@@ -5,7 +5,27 @@ from . import extractors, feature_bank
 from .extractors import FeatureExtractor, MultivariateFeature, _get_underlying_func
 
 
-def get_feature_predecessors(feature_or_extractor: Callable):
+def get_feature_predecessors(feature_or_extractor: Callable) -> list:
+    """Get the dependency hierarchy for a feature or feature extractor.
+
+    This function recursively traverses the `parent_extractor_type` attribute
+    of a feature or extractor to build a list representing its dependency
+    lineage.
+
+    Parameters
+    ----------
+    feature_or_extractor : callable
+        The feature function or :class:`FeatureExtractor` class to inspect.
+
+    Returns
+    -------
+    list
+        A nested list representing the dependency tree. For a simple linear
+        chain, this will be a flat list from the specific feature up to the
+        base `FeatureExtractor`. For multiple dependencies, it will contain
+        tuples of sub-dependencies.
+
+    """
     current = _get_underlying_func(feature_or_extractor)
     if current is FeatureExtractor:
         return [current]
@@ -20,18 +40,59 @@ def get_feature_predecessors(feature_or_extractor: Callable):
     return [current, tuple(predecessors)]
 
 
-def get_feature_kind(feature: Callable):
+def get_feature_kind(feature: Callable) -> MultivariateFeature:
+    """Get the 'kind' of a feature function.
+
+    The feature kind (e.g., univariate, bivariate) is typically attached by a
+    decorator.
+
+    Parameters
+    ----------
+    feature : callable
+        The feature function to inspect.
+
+    Returns
+    -------
+    MultivariateFeature
+        An instance of the feature kind (e.g., `UnivariateFeature()`).
+
+    """
     return _get_underlying_func(feature).feature_kind
 
 
-def get_all_features():
+def get_all_features() -> list[tuple[str, Callable]]:
+    """Get a list of all available feature functions.
+
+    Scans the `eegdash.features.feature_bank` module for functions that have
+    been decorated to have a `feature_kind` attribute.
+
+    Returns
+    -------
+    list[tuple[str, callable]]
+        A list of (name, function) tuples for all discovered features.
+
+    """
+
     def isfeature(x):
         return hasattr(_get_underlying_func(x), "feature_kind")
 
     return inspect.getmembers(feature_bank, isfeature)
 
 
-def get_all_feature_extractors():
+def get_all_feature_extractors() -> list[tuple[str, type[FeatureExtractor]]]:
+    """Get a list of all available `FeatureExtractor` classes.
+
+    Scans the `eegdash.features.feature_bank` module for all classes that
+    subclass :class:`~eegdash.features.extractors.FeatureExtractor`.
+
+    Returns
+    -------
+    list[tuple[str, type[FeatureExtractor]]]
+        A list of (name, class) tuples for all discovered feature extractors,
+        including the base `FeatureExtractor` itself.
+
+    """
+
     def isfeatureextractor(x):
         return inspect.isclass(x) and issubclass(x, FeatureExtractor)
 
@@ -41,7 +102,19 @@ def get_all_feature_extractors():
     ]
 
 
-def get_all_feature_kinds():
+def get_all_feature_kinds() -> list[tuple[str, type[MultivariateFeature]]]:
+    """Get a list of all available feature 'kind' classes.
+
+    Scans the `eegdash.features.extractors` module for all classes that
+    subclass :class:`~eegdash.features.extractors.MultivariateFeature`.
+
+    Returns
+    -------
+    list[tuple[str, type[MultivariateFeature]]]
+        A list of (name, class) tuples for all discovered feature kinds.
+
+    """
+
     def isfeaturekind(x):
         return inspect.isclass(x) and issubclass(x, MultivariateFeature)
 
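
Taken together, the inspect.py helpers documented above form a small discovery API over the feature bank. The sketch below shows how they might be combined; it is not part of the diff, and the printed values are indicative only.

    # Sketch of the discovery helpers documented above.
    from eegdash.features.inspect import (
        get_all_feature_extractors,
        get_all_features,
        get_feature_kind,
        get_feature_predecessors,
    )

    for name, func in get_all_features():
        kind = get_feature_kind(func)             # e.g. a UnivariateFeature instance
        lineage = get_feature_predecessors(func)  # feature -> ... -> FeatureExtractor
        print(name, type(kind).__name__, lineage)

    print([name for name, _ in get_all_feature_extractors()])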