eegdash 0.4.0.dev153__py3-none-any.whl → 0.4.0.dev162__py3-none-any.whl

This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.

Potentially problematic release.


This version of eegdash might be problematic; see the package's registry page for more details.

@@ -12,6 +12,21 @@ from .extractors import (
12
12
 
13
13
 
14
14
  class FeaturePredecessor:
15
+ """A decorator to specify parent extractors for a feature function.
16
+
17
+ This decorator attaches a list of parent extractor types to a feature
18
+ extraction function. This information can be used to build a dependency
19
+ graph of features.
20
+
21
+ Parameters
22
+ ----------
23
+ *parent_extractor_type : list of Type
24
+ A list of feature extractor classes (subclasses of
25
+ :class:`~eegdash.features.extractors.FeatureExtractor`) that this
26
+ feature depends on.
27
+
28
+ """
29
+
15
30
  def __init__(self, *parent_extractor_type: List[Type]):
16
31
  parent_cls = parent_extractor_type
17
32
  if not parent_cls:
@@ -20,17 +35,58 @@ class FeaturePredecessor:
20
35
  assert issubclass(p_cls, FeatureExtractor)
21
36
  self.parent_extractor_type = parent_cls
22
37
 
23
def __call__(self, func: Callable) -> Callable:
    """Record the parent extractor types on ``func``.

    Parameters
    ----------
    func : callable
        The feature function to decorate. It may be wrapped (for example by
        :func:`functools.partial`); the attribute is written to the function
        returned by ``_get_underlying_func``.

    Returns
    -------
    callable
        ``func`` exactly as it was passed in.

    """
    target = _get_underlying_func(func)
    setattr(target, "parent_extractor_type", self.parent_extractor_type)
    return func
27
56
 
28
57
 
29
58
  class FeatureKind:
59
+ """A decorator to specify the kind of a feature.
60
+
61
+ This decorator attaches a "feature kind" (e.g., univariate, bivariate)
62
+ to a feature extraction function.
63
+
64
+ Parameters
65
+ ----------
66
+ feature_kind : MultivariateFeature
67
+ An instance of a feature kind class, such as
68
+ :class:`~eegdash.features.extractors.UnivariateFeature` or
69
+ :class:`~eegdash.features.extractors.BivariateFeature`.
70
+
71
+ """
72
+
30
73
  def __init__(self, feature_kind: MultivariateFeature):
31
74
  self.feature_kind = feature_kind
32
75
 
33
- def __call__(self, func):
76
+ def __call__(self, func: Callable) -> Callable:
77
+ """Apply the decorator to a function.
78
+
79
+ Parameters
80
+ ----------
81
+ func : callable
82
+ The feature extraction function to decorate.
83
+
84
+ Returns
85
+ -------
86
+ callable
87
+ The decorated function with the `feature_kind` attribute set.
88
+
89
+ """
34
90
  f = _get_underlying_func(func)
35
91
  f.feature_kind = self.feature_kind
36
92
  return func
@@ -38,9 +94,33 @@ class FeatureKind:
38
94
 
39
95
  # Syntax sugar
40
96
  univariate_feature = FeatureKind(UnivariateFeature())
97
+ """Decorator to mark a feature as univariate.
98
+
99
+ This is a convenience instance of :class:`FeatureKind` pre-configured for
100
+ univariate features.
101
+ """
41
102
 
42
103
 
43
- def bivariate_feature(func, directed=False):
104
+ def bivariate_feature(func: Callable, directed: bool = False) -> Callable:
105
+ """Decorator to mark a feature as bivariate.
106
+
107
+ This decorator specifies that the feature operates on pairs of channels.
108
+
109
+ Parameters
110
+ ----------
111
+ func : callable
112
+ The feature extraction function to decorate.
113
+ directed : bool, default False
114
+ If True, the feature is directed (e.g., connectivity from channel A
115
+ to B is different from B to A). If False, the feature is undirected.
116
+
117
+ Returns
118
+ -------
119
+ callable
120
+ The decorated function with the appropriate bivariate feature kind
121
+ attached.
122
+
123
+ """
44
124
  if directed:
45
125
  kind = DirectedBivariateFeature()
46
126
  else:
@@ -49,3 +129,8 @@ def bivariate_feature(func, directed=False):
49
129
 
50
130
 
51
131
  multivariate_feature = FeatureKind(MultivariateFeature())
132
+ """Decorator to mark a feature as multivariate.
133
+
134
+ This is a convenience instance of :class:`FeatureKind` pre-configured for
135
+ multivariate features, which operate on all channels simultaneously.
136
+ """
@@ -7,7 +7,23 @@ import numpy as np
7
7
  from numba.core.dispatcher import Dispatcher
8
8
 
9
9
 
10
- def _get_underlying_func(func):
10
+ def _get_underlying_func(func: Callable) -> Callable:
11
+ """Get the underlying function from a potential wrapper.
12
+
13
+ This helper unwraps functions that might be wrapped by `functools.partial`
14
+ or `numba.dispatcher.Dispatcher`.
15
+
16
+ Parameters
17
+ ----------
18
+ func : callable
19
+ The function to unwrap.
20
+
21
+ Returns
22
+ -------
23
+ callable
24
+ The underlying Python function.
25
+
26
+ """
11
27
  f = func
12
28
  if isinstance(f, partial):
13
29
  f = f.func
@@ -17,22 +33,46 @@ def _get_underlying_func(func):
17
33
 
18
34
 
19
35
  class TrainableFeature(ABC):
36
+ """Abstract base class for features that require training.
37
+
38
+ This ABC defines the interface for feature extractors that need to be
39
+ fitted on data before they can be used. It includes methods for fitting
40
+ the feature extractor and for resetting its state.
41
+ """
42
+
20
43
  def __init__(self):
21
44
  self._is_trained = False
22
45
  self.clear()
23
46
 
24
47
  @abstractmethod
25
48
  def clear(self):
49
+ """Reset the internal state of the feature extractor."""
26
50
  pass
27
51
 
28
52
  @abstractmethod
29
53
  def partial_fit(self, *x, y=None):
54
+ """Update the feature extractor's state with a batch of data.
55
+
56
+ Parameters
57
+ ----------
58
+ *x : tuple
59
+ The input data for fitting.
60
+ y : any, optional
61
+ The target data, if required for supervised training.
62
+
63
+ """
30
64
  pass
31
65
 
32
66
  def fit(self):
67
+ """Finalize the training of the feature extractor.
68
+
69
+ This method should be called after all data has been seen via
70
+ `partial_fit`. It marks the feature as fitted.
71
+ """
33
72
  self._is_fitted = True
34
73
 
35
74
  def __call__(self, *args, **kwargs):
75
+ """Check if the feature is fitted before execution."""
36
76
  if not self._is_fitted:
37
77
  raise RuntimeError(
38
78
  f"{self.__class__} cannot be called, it has to be fitted first."
@@ -40,6 +80,22 @@ class TrainableFeature(ABC):
40
80
 
41
81
 
42
82
  class FeatureExtractor(TrainableFeature):
83
+ """A composite feature extractor that applies multiple feature functions.
84
+
85
+ This class orchestrates the application of a dictionary of feature
86
+ extraction functions to input data. It can handle nested extractors,
87
+ pre-processing, and trainable features.
88
+
89
+ Parameters
90
+ ----------
91
+ feature_extractors : dict[str, callable]
92
+ A dictionary where keys are feature names and values are the feature
93
+ extraction functions or other `FeatureExtractor` instances.
94
+ **preprocess_kwargs
95
+ Keyword arguments to be passed to the `preprocess` method.
96
+
97
+ """
98
+
43
99
  def __init__(
44
100
  self, feature_extractors: Dict[str, Callable], **preprocess_kwargs: Dict
45
101
  ):
@@ -63,30 +119,64 @@ class FeatureExtractor(TrainableFeature):
63
119
  if isinstance(fe, partial):
64
120
  self.features_kwargs[fn] = fe.keywords
65
121
 
66
def _validate_execution_tree(self, feature_extractors: dict) -> dict:
    """Check that every feature may be nested under this extractor.

    A feature declares the extractor classes it accepts as parents through a
    ``parent_extractor_type`` attribute (set by ``FeaturePredecessor``).
    Features without that attribute accept ``FeatureExtractor`` only.

    Parameters
    ----------
    feature_extractors : dict
        Mapping of feature names to feature callables.

    Returns
    -------
    dict
        ``feature_extractors``, unchanged.

    Raises
    ------
    TypeError
        If ``type(self)`` is not among a feature's allowed parent types.

    """
    own_type = type(self)
    for name, feature in feature_extractors.items():
        allowed_parents = getattr(
            _get_underlying_func(feature),
            "parent_extractor_type",
            [FeatureExtractor],
        )
        if own_type not in allowed_parents:
            raise TypeError(
                f"Feature '{name}' cannot be a child of {own_type.__name__}"
            )
    return feature_extractors
72
132
 
73
def _check_is_trainable(self, feature_extractors: dict) -> bool:
    """Return whether any contained feature requires training.

    Parameters
    ----------
    feature_extractors : dict
        Mapping of feature names to feature callables or nested
        ``FeatureExtractor`` instances.

    Returns
    -------
    bool
        ``True`` if a nested ``FeatureExtractor`` reports ``_is_trainable``,
        or if any other entry unwraps to a ``TrainableFeature``.

    """

    def needs_training(feature) -> bool:
        # Nested extractors already know whether their own subtree trains.
        if isinstance(feature, FeatureExtractor):
            return bool(feature._is_trainable)
        return isinstance(_get_underlying_func(feature), TrainableFeature)

    return any(
        needs_training(feature) for feature in feature_extractors.values()
    )
85
142
 
86
143
  def preprocess(self, *x, **kwargs):
144
+ """Apply pre-processing to the input data.
145
+
146
+ Parameters
147
+ ----------
148
+ *x : tuple
149
+ Input data.
150
+ **kwargs
151
+ Additional keyword arguments.
152
+
153
+ Returns
154
+ -------
155
+ tuple
156
+ The pre-processed data.
157
+
158
+ """
87
159
  return (*x,)
88
160
 
89
- def __call__(self, *x, _batch_size=None, _ch_names=None):
161
+ def __call__(self, *x, _batch_size=None, _ch_names=None) -> dict:
162
+ """Apply all feature extractors to the input data.
163
+
164
+ Parameters
165
+ ----------
166
+ *x : tuple
167
+ Input data.
168
+ _batch_size : int, optional
169
+ The number of samples in the batch.
170
+ _ch_names : list of str, optional
171
+ The names of the channels in the input data.
172
+
173
+ Returns
174
+ -------
175
+ dict
176
+ A dictionary where keys are feature names and values are the
177
+ computed feature values.
178
+
179
+ """
90
180
  assert _batch_size is not None
91
181
  assert _ch_names is not None
92
182
  if self._is_trainable:
@@ -100,59 +190,83 @@ class FeatureExtractor(TrainableFeature):
100
190
  r = f(*z, _batch_size=_batch_size, _ch_names=_ch_names)
101
191
  else:
102
192
  r = f(*z)
103
- f = _get_underlying_func(f)
104
- if hasattr(f, "feature_kind"):
105
- r = f.feature_kind(r, _ch_names=_ch_names)
193
+ f_und = _get_underlying_func(f)
194
+ if hasattr(f_und, "feature_kind"):
195
+ r = f_und.feature_kind(r, _ch_names=_ch_names)
106
196
  if not isinstance(fname, str) or not fname:
107
- if isinstance(f, FeatureExtractor) or not hasattr(f, "__name__"):
108
- fname = ""
109
- else:
110
- fname = f.__name__
197
+ fname = getattr(f_und, "__name__", "")
111
198
  if isinstance(r, dict):
112
- if fname:
113
- fname += "_"
199
+ prefix = f"{fname}_" if fname else ""
114
200
  for k, v in r.items():
115
- self._add_feature_to_dict(results_dict, fname + k, v, _batch_size)
201
+ self._add_feature_to_dict(results_dict, prefix + k, v, _batch_size)
116
202
  else:
117
203
  self._add_feature_to_dict(results_dict, fname, r, _batch_size)
118
204
  return results_dict
119
205
 
120
def _add_feature_to_dict(
    self, results_dict: dict, name: str, value: object, batch_size: int
) -> None:
    """Store one computed feature in ``results_dict`` under ``name``.

    Parameters
    ----------
    results_dict : dict
        Mapping being filled with feature values; modified in place.
    name : str
        Key under which ``value`` is stored.
    value : object
        The feature value. NumPy arrays must have ``batch_size`` rows;
        other values are stored without checks.
    batch_size : int
        Expected length of the first axis of array values.

    """
    if isinstance(value, np.ndarray):
        # Array features must carry one entry per sample in the batch.
        assert value.shape[0] == batch_size, (
            f"feature '{name}' has {value.shape[0]} rows, expected {batch_size}"
        )
    results_dict[name] = value
126
213
 
127
214
  def clear(self):
215
+ """Clear the state of all trainable sub-features."""
128
216
  if not self._is_trainable:
129
217
  return
130
- for fname, f in self.feature_extractors_dict.items():
131
- f = _get_underlying_func(f)
132
- if isinstance(f, TrainableFeature):
133
- f.clear()
218
+ for f in self.feature_extractors_dict.values():
219
+ if isinstance(_get_underlying_func(f), TrainableFeature):
220
+ _get_underlying_func(f).clear()
134
221
 
135
222
  def partial_fit(self, *x, y=None):
223
+ """Partially fit all trainable sub-features."""
136
224
  if not self._is_trainable:
137
225
  return
138
226
  z = self.preprocess(*x, **self.preprocess_kwargs)
139
- for fname, f in self.feature_extractors_dict.items():
140
- f = _get_underlying_func(f)
141
- if isinstance(f, TrainableFeature):
142
- f.partial_fit(*z, y=y)
227
+ if not isinstance(z, tuple):
228
+ z = (z,)
229
+ for f in self.feature_extractors_dict.values():
230
+ if isinstance(_get_underlying_func(f), TrainableFeature):
231
+ _get_underlying_func(f).partial_fit(*z, y=y)
143
232
 
144
233
  def fit(self):
234
+ """Fit all trainable sub-features."""
145
235
  if not self._is_trainable:
146
236
  return
147
- for fname, f in self.feature_extractors_dict.items():
148
- f = _get_underlying_func(f)
149
- if isinstance(f, TrainableFeature):
237
+ for f in self.feature_extractors_dict.values():
238
+ if isinstance(_get_underlying_func(f), TrainableFeature):
150
239
  f.fit()
151
240
  super().fit()
152
241
 
153
242
 
154
243
  class MultivariateFeature:
155
- def __call__(self, x, _ch_names=None):
244
+ """A mixin for features that operate on multiple channels.
245
+
246
+ This class provides a `__call__` method that converts a feature array into
247
+ a dictionary with named features, where names are derived from channel
248
+ names.
249
+ """
250
+
251
+ def __call__(
252
+ self, x: np.ndarray, _ch_names: list[str] | None = None
253
+ ) -> dict | np.ndarray:
254
+ """Convert a feature array to a named dictionary.
255
+
256
+ Parameters
257
+ ----------
258
+ x : numpy.ndarray
259
+ The computed feature array.
260
+ _ch_names : list of str, optional
261
+ The list of channel names.
262
+
263
+ Returns
264
+ -------
265
+ dict or numpy.ndarray
266
+ A dictionary of named features, or the original array if feature
267
+ channel names cannot be generated.
268
+
269
+ """
156
270
  assert _ch_names is not None
157
271
  f_channels = self.feature_channel_names(_ch_names)
158
272
  if isinstance(x, dict):
@@ -163,37 +277,66 @@ class MultivariateFeature:
163
277
  return self._array_to_dict(x, f_channels)
164
278
 
165
279
  @staticmethod
166
- def _array_to_dict(x, f_channels, name=""):
280
+ def _array_to_dict(
281
+ x: np.ndarray, f_channels: list[str], name: str = ""
282
+ ) -> dict | np.ndarray:
283
+ """Convert a numpy array to a dictionary with named keys."""
167
284
  assert isinstance(x, np.ndarray)
168
- if len(f_channels) == 0:
169
- assert x.ndim == 1
170
- if name:
171
- return {name: x}
172
- return x
173
- assert x.shape[1] == len(f_channels)
285
+ if not f_channels:
286
+ return {name: x} if name else x
287
+ assert x.shape[1] == len(f_channels), f"{x.shape[1]} != {len(f_channels)}"
174
288
  x = x.swapaxes(0, 1)
175
- names = [f"{name}_{ch}" for ch in f_channels] if name else f_channels
289
+ prefix = f"{name}_" if name else ""
290
+ names = [f"{prefix}{ch}" for ch in f_channels]
176
291
  return dict(zip(names, x))
177
292
 
178
def feature_channel_names(self, ch_names: list[str]) -> list[str]:
    """Name the output channels of this feature kind.

    The generic multivariate kind produces no per-channel outputs, so it
    returns an empty list regardless of ``ch_names``. Subclasses such as
    ``UnivariateFeature`` and ``BivariateFeature`` override this.

    Parameters
    ----------
    ch_names : list of str
        Names of the input channels.

    Returns
    -------
    list of str
        Always an empty list here.

    """
    return list()
180
308
 
181
309
 
182
310
  class UnivariateFeature(MultivariateFeature):
183
- def feature_channel_names(self, ch_names):
311
+ """A feature kind for operations applied to each channel independently."""
312
+
313
+ def feature_channel_names(self, ch_names: list[str]) -> list[str]:
314
+ """Return the channel names themselves as feature names."""
184
315
  return ch_names
185
316
 
186
317
 
187
318
  class BivariateFeature(MultivariateFeature):
188
- def __init__(self, *args, channel_pair_format="{}<>{}"):
319
+ """A feature kind for operations on pairs of channels.
320
+
321
+ Parameters
322
+ ----------
323
+ channel_pair_format : str, default="{}<>{}"
324
+ A format string used to create feature names from pairs of
325
+ channel names.
326
+
327
+ """
328
+
329
def __init__(self, *args, channel_pair_format: str = "{}<>{}"):
    """Configure how channel pairs are named.

    Parameters
    ----------
    *args
        Passed through unchanged to the parent class ``__init__``.
    channel_pair_format : str, default "{}<>{}"
        Template whose two ``{}`` fields receive the names of the first and
        second channel of a pair.

    """
    super().__init__(*args)
    # Read later by ``feature_channel_names`` when building pair names.
    self.channel_pair_format = channel_pair_format
191
332
 
192
333
  @staticmethod
193
- def get_pair_iterators(n):
334
+ def get_pair_iterators(n: int) -> tuple[np.ndarray, np.ndarray]:
335
+ """Get indices for unique, unordered pairs of channels."""
194
336
  return np.triu_indices(n, 1)
195
337
 
196
- def feature_channel_names(self, ch_names):
338
+ def feature_channel_names(self, ch_names: list[str]) -> list[str]:
339
+ """Generate feature names for each pair of channels."""
197
340
  return [
198
341
  self.channel_pair_format.format(ch_names[i], ch_names[j])
199
342
  for i, j in zip(*self.get_pair_iterators(len(ch_names)))
@@ -201,8 +344,11 @@ class BivariateFeature(MultivariateFeature):
201
344
 
202
345
 
203
346
  class DirectedBivariateFeature(BivariateFeature):
347
+ """A feature kind for directed operations on pairs of channels."""
348
+
204
349
  @staticmethod
205
- def get_pair_iterators(n):
350
+ def get_pair_iterators(n: int) -> list[np.ndarray]:
351
+ """Get indices for all ordered pairs of channels (excluding self-pairs)."""
206
352
  return [
207
353
  np.append(a, b)
208
354
  for a, b in zip(np.tril_indices(n, -1), np.triu_indices(n, 1))
@@ -5,7 +5,27 @@ from . import extractors, feature_bank
5
5
  from .extractors import FeatureExtractor, MultivariateFeature, _get_underlying_func
6
6
 
7
7
 
8
- def get_feature_predecessors(feature_or_extractor: Callable):
8
+ def get_feature_predecessors(feature_or_extractor: Callable) -> list:
9
+ """Get the dependency hierarchy for a feature or feature extractor.
10
+
11
+ This function recursively traverses the `parent_extractor_type` attribute
12
+ of a feature or extractor to build a list representing its dependency
13
+ lineage.
14
+
15
+ Parameters
16
+ ----------
17
+ feature_or_extractor : callable
18
+ The feature function or :class:`FeatureExtractor` class to inspect.
19
+
20
+ Returns
21
+ -------
22
+ list
23
+ A nested list representing the dependency tree. For a simple linear
24
+ chain, this will be a flat list from the specific feature up to the
25
+ base `FeatureExtractor`. For multiple dependencies, it will contain
26
+ tuples of sub-dependencies.
27
+
28
+ """
9
29
  current = _get_underlying_func(feature_or_extractor)
10
30
  if current is FeatureExtractor:
11
31
  return [current]
@@ -20,18 +40,59 @@ def get_feature_predecessors(feature_or_extractor: Callable):
20
40
  return [current, tuple(predecessors)]
21
41
 
22
42
 
23
def get_feature_kind(feature: Callable) -> MultivariateFeature:
    """Return the kind attached to a feature function.

    The kind is read from the ``feature_kind`` attribute of the underlying
    (unwrapped) function, which the ``FeatureKind`` decorators set.

    Parameters
    ----------
    feature : callable
        The feature function to inspect; it may be wrapped.

    Returns
    -------
    MultivariateFeature
        The attached kind instance, e.g. ``UnivariateFeature()``.

    Raises
    ------
    AttributeError
        If the underlying function has no ``feature_kind`` attribute.

    """
    underlying = _get_underlying_func(feature)
    return getattr(underlying, "feature_kind")
25
61
 
26
62
 
27
def get_all_features() -> list[tuple[str, Callable]]:
    """List every feature function exposed by ``feature_bank``.

    A member counts as a feature when its underlying function carries a
    ``feature_kind`` attribute, i.e. it was decorated with a feature kind.

    Returns
    -------
    list of tuple of (str, callable)
        ``(name, member)`` pairs as produced by :func:`inspect.getmembers`,
        which sorts them by name.

    """
    return inspect.getmembers(
        feature_bank,
        lambda member: hasattr(_get_underlying_func(member), "feature_kind"),
    )
32
80
 
33
81
 
34
- def get_all_feature_extractors():
82
+ def get_all_feature_extractors() -> list[tuple[str, type[FeatureExtractor]]]:
83
+ """Get a list of all available `FeatureExtractor` classes.
84
+
85
+ Scans the `eegdash.features.feature_bank` module for all classes that
86
+ subclass :class:`~eegdash.features.extractors.FeatureExtractor`.
87
+
88
+ Returns
89
+ -------
90
+ list[tuple[str, type[FeatureExtractor]]]
91
+ A list of (name, class) tuples for all discovered feature extractors,
92
+ including the base `FeatureExtractor` itself.
93
+
94
+ """
95
+
35
96
  def isfeatureextractor(x):
36
97
  return inspect.isclass(x) and issubclass(x, FeatureExtractor)
37
98
 
@@ -41,7 +102,19 @@ def get_all_feature_extractors():
41
102
  ]
42
103
 
43
104
 
44
- def get_all_feature_kinds():
105
+ def get_all_feature_kinds() -> list[tuple[str, type[MultivariateFeature]]]:
106
+ """Get a list of all available feature 'kind' classes.
107
+
108
+ Scans the `eegdash.features.extractors` module for all classes that
109
+ subclass :class:`~eegdash.features.extractors.MultivariateFeature`.
110
+
111
+ Returns
112
+ -------
113
+ list[tuple[str, type[MultivariateFeature]]]
114
+ A list of (name, class) tuples for all discovered feature kinds.
115
+
116
+ """
117
+
45
118
  def isfeaturekind(x):
46
119
  return inspect.isclass(x) and issubclass(x, MultivariateFeature)
47
120