eegdash 0.4.0.dev173498563__py3-none-any.whl → 0.4.1.dev185__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of eegdash has been flagged as potentially problematic; see the package's registry page for more details.

@@ -10,8 +10,31 @@ from .extractors import (
10
10
  _get_underlying_func,
11
11
  )
12
12
 
13
+ __all__ = [
14
+ "bivariate_feature",
15
+ "FeatureKind",
16
+ "FeaturePredecessor",
17
+ "multivariate_feature",
18
+ "univariate_feature",
19
+ ]
20
+
13
21
 
14
22
  class FeaturePredecessor:
23
+ """A decorator to specify parent extractors for a feature function.
24
+
25
+ This decorator attaches a list of parent extractor types to a feature
26
+ extraction function. This information can be used to build a dependency
27
+ graph of features.
28
+
29
+ Parameters
30
+ ----------
31
+ *parent_extractor_type : list of Type
32
+ A list of feature extractor classes (subclasses of
33
+ :class:`~eegdash.features.extractors.FeatureExtractor`) that this
34
+ feature depends on.
35
+
36
+ """
37
+
15
38
  def __init__(self, *parent_extractor_type: List[Type]):
16
39
  parent_cls = parent_extractor_type
17
40
  if not parent_cls:
@@ -20,17 +43,58 @@ class FeaturePredecessor:
20
43
  assert issubclass(p_cls, FeatureExtractor)
21
44
  self.parent_extractor_type = parent_cls
22
45
 
23
- def __call__(self, func: Callable):
46
+ def __call__(self, func: Callable) -> Callable:
47
+ """Apply the decorator to a function.
48
+
49
+ Parameters
50
+ ----------
51
+ func : callable
52
+ The feature extraction function to decorate.
53
+
54
+ Returns
55
+ -------
56
+ callable
57
+ The decorated function with the `parent_extractor_type` attribute
58
+ set.
59
+
60
+ """
24
61
  f = _get_underlying_func(func)
25
62
  f.parent_extractor_type = self.parent_extractor_type
26
63
  return func
27
64
 
28
65
 
29
66
  class FeatureKind:
67
+ """A decorator to specify the kind of a feature.
68
+
69
+ This decorator attaches a "feature kind" (e.g., univariate, bivariate)
70
+ to a feature extraction function.
71
+
72
+ Parameters
73
+ ----------
74
+ feature_kind : ~eegdash.features.extractors.MultivariateFeature
75
+ An instance of a feature kind class, such as
76
+ :class:`~eegdash.features.extractors.UnivariateFeature` or
77
+ :class:`~eegdash.features.extractors.BivariateFeature`.
78
+
79
+ """
80
+
30
81
  def __init__(self, feature_kind: MultivariateFeature):
31
82
  self.feature_kind = feature_kind
32
83
 
33
- def __call__(self, func):
84
+ def __call__(self, func: Callable) -> Callable:
85
+ """Apply the decorator to a function.
86
+
87
+ Parameters
88
+ ----------
89
+ func : callable
90
+ The feature extraction function to decorate.
91
+
92
+ Returns
93
+ -------
94
+ callable
95
+ The decorated function with the `feature_kind` attribute set.
96
+
97
+ """
34
98
  f = _get_underlying_func(func)
35
99
  f.feature_kind = self.feature_kind
36
100
  return func
@@ -38,9 +102,33 @@ class FeatureKind:
38
102
 
39
103
  # Syntax sugar
40
104
  univariate_feature = FeatureKind(UnivariateFeature())
105
+ """Decorator to mark a feature as univariate.
106
+
107
+ This is a convenience instance of :class:`~eegdash.features.decorators.FeatureKind` pre-configured for
108
+ univariate features.
109
+ """
41
110
 
42
111
 
43
- def bivariate_feature(func, directed=False):
112
+ def bivariate_feature(func: Callable, directed: bool = False) -> Callable:
113
+ """Decorator to mark a feature as bivariate.
114
+
115
+ This decorator specifies that the feature operates on pairs of channels.
116
+
117
+ Parameters
118
+ ----------
119
+ func : callable
120
+ The feature extraction function to decorate.
121
+ directed : bool, default False
122
+ If True, the feature is directed (e.g., connectivity from channel A
123
+ to B is different from B to A). If False, the feature is undirected.
124
+
125
+ Returns
126
+ -------
127
+ callable
128
+ The decorated function with the appropriate bivariate feature kind
129
+ attached.
130
+
131
+ """
44
132
  if directed:
45
133
  kind = DirectedBivariateFeature()
46
134
  else:
@@ -49,3 +137,8 @@ def bivariate_feature(func, directed=False):
49
137
 
50
138
 
51
139
  multivariate_feature = FeatureKind(MultivariateFeature())
140
+ """Decorator to mark a feature as multivariate.
141
+
142
+ This is a convenience instance of :class:`~eegdash.features.decorators.FeatureKind` pre-configured for
143
+ multivariate features, which operate on all channels simultaneously.
144
+ """
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  from abc import ABC, abstractmethod
2
4
  from collections.abc import Callable
3
5
  from functools import partial
@@ -6,8 +8,33 @@ from typing import Dict
6
8
  import numpy as np
7
9
  from numba.core.dispatcher import Dispatcher
8
10
 
11
+ __all__ = [
12
+ "BivariateFeature",
13
+ "DirectedBivariateFeature",
14
+ "FeatureExtractor",
15
+ "MultivariateFeature",
16
+ "TrainableFeature",
17
+ "UnivariateFeature",
18
+ ]
19
+
20
+
21
+ def _get_underlying_func(func: Callable) -> Callable:
22
+ """Get the underlying function from a potential wrapper.
23
+
24
+ This helper unwraps functions that might be wrapped by `functools.partial`
25
+ or `numba.dispatcher.Dispatcher`.
26
+
27
+ Parameters
28
+ ----------
29
+ func : callable
30
+ The function to unwrap.
9
31
 
10
- def _get_underlying_func(func):
32
+ Returns
33
+ -------
34
+ callable
35
+ The underlying Python function.
36
+
37
+ """
11
38
  f = func
12
39
  if isinstance(f, partial):
13
40
  f = f.func
@@ -17,22 +44,46 @@ def _get_underlying_func(func):
17
44
 
18
45
 
19
46
  class TrainableFeature(ABC):
47
+ """Abstract base class for features that require training.
48
+
49
+ This ABC defines the interface for feature extractors that need to be
50
+ fitted on data before they can be used. It includes methods for fitting
51
+ the feature extractor and for resetting its state.
52
+ """
53
+
20
54
  def __init__(self):
21
55
  self._is_trained = False
22
56
  self.clear()
23
57
 
24
58
  @abstractmethod
25
59
  def clear(self):
60
+ """Reset the internal state of the feature extractor."""
26
61
  pass
27
62
 
28
63
  @abstractmethod
29
64
  def partial_fit(self, *x, y=None):
65
+ """Update the feature extractor's state with a batch of data.
66
+
67
+ Parameters
68
+ ----------
69
+ *x : tuple
70
+ The input data for fitting.
71
+ y : any, optional
72
+ The target data, if required for supervised training.
73
+
74
+ """
30
75
  pass
31
76
 
32
77
  def fit(self):
78
+ """Finalize the training of the feature extractor.
79
+
80
+ This method should be called after all data has been seen via
81
+ `partial_fit`. It marks the feature as fitted.
82
+ """
33
83
  self._is_fitted = True
34
84
 
35
85
  def __call__(self, *args, **kwargs):
86
+ """Check if the feature is fitted before execution."""
36
87
  if not self._is_fitted:
37
88
  raise RuntimeError(
38
89
  f"{self.__class__} cannot be called, it has to be fitted first."
@@ -40,6 +91,22 @@ class TrainableFeature(ABC):
40
91
 
41
92
 
42
93
  class FeatureExtractor(TrainableFeature):
94
+ """A composite feature extractor that applies multiple feature functions.
95
+
96
+ This class orchestrates the application of a dictionary of feature
97
+ extraction functions to input data. It can handle nested extractors,
98
+ pre-processing, and trainable features.
99
+
100
+ Parameters
101
+ ----------
102
+ feature_extractors : dict[str, callable]
103
+ A dictionary where keys are feature names and values are the feature
104
+ extraction functions or other `FeatureExtractor` instances.
105
+ **preprocess_kwargs
106
+ Keyword arguments to be passed to the `preprocess` method.
107
+
108
+ """
109
+
43
110
  def __init__(
44
111
  self, feature_extractors: Dict[str, Callable], **preprocess_kwargs: Dict
45
112
  ):
@@ -63,30 +130,64 @@ class FeatureExtractor(TrainableFeature):
63
130
  if isinstance(fe, partial):
64
131
  self.features_kwargs[fn] = fe.keywords
65
132
 
66
- def _validate_execution_tree(self, feature_extractors):
133
+ def _validate_execution_tree(self, feature_extractors: dict) -> dict:
134
+ """Validate the feature dependency graph."""
67
135
  for fname, f in feature_extractors.items():
68
136
  f = _get_underlying_func(f)
69
137
  pe_type = getattr(f, "parent_extractor_type", [FeatureExtractor])
70
- assert type(self) in pe_type
138
+ if type(self) not in pe_type:
139
+ raise TypeError(
140
+ f"Feature '{fname}' cannot be a child of {type(self).__name__}"
141
+ )
71
142
  return feature_extractors
72
143
 
73
- def _check_is_trainable(self, feature_extractors):
74
- is_trainable = False
144
+ def _check_is_trainable(self, feature_extractors: dict) -> bool:
145
+ """Check if any of the contained features are trainable."""
75
146
  for fname, f in feature_extractors.items():
76
147
  if isinstance(f, FeatureExtractor):
77
- is_trainable = f._is_trainable
78
- else:
79
- f = _get_underlying_func(f)
80
- if isinstance(f, TrainableFeature):
81
- is_trainable = True
82
- if is_trainable:
83
- break
84
- return is_trainable
148
+ if f._is_trainable:
149
+ return True
150
+ elif isinstance(_get_underlying_func(f), TrainableFeature):
151
+ return True
152
+ return False
85
153
 
86
154
  def preprocess(self, *x, **kwargs):
155
+ """Apply pre-processing to the input data.
156
+
157
+ Parameters
158
+ ----------
159
+ *x : tuple
160
+ Input data.
161
+ **kwargs
162
+ Additional keyword arguments.
163
+
164
+ Returns
165
+ -------
166
+ tuple
167
+ The pre-processed data.
168
+
169
+ """
87
170
  return (*x,)
88
171
 
89
- def __call__(self, *x, _batch_size=None, _ch_names=None):
172
+ def __call__(self, *x, _batch_size=None, _ch_names=None) -> dict:
173
+ """Apply all feature extractors to the input data.
174
+
175
+ Parameters
176
+ ----------
177
+ *x : tuple
178
+ Input data.
179
+ _batch_size : int, optional
180
+ The number of samples in the batch.
181
+ _ch_names : list of str, optional
182
+ The names of the channels in the input data.
183
+
184
+ Returns
185
+ -------
186
+ dict
187
+ A dictionary where keys are feature names and values are the
188
+ computed feature values.
189
+
190
+ """
90
191
  assert _batch_size is not None
91
192
  assert _ch_names is not None
92
193
  if self._is_trainable:
@@ -100,59 +201,83 @@ class FeatureExtractor(TrainableFeature):
100
201
  r = f(*z, _batch_size=_batch_size, _ch_names=_ch_names)
101
202
  else:
102
203
  r = f(*z)
103
- f = _get_underlying_func(f)
104
- if hasattr(f, "feature_kind"):
105
- r = f.feature_kind(r, _ch_names=_ch_names)
204
+ f_und = _get_underlying_func(f)
205
+ if hasattr(f_und, "feature_kind"):
206
+ r = f_und.feature_kind(r, _ch_names=_ch_names)
106
207
  if not isinstance(fname, str) or not fname:
107
- if isinstance(f, FeatureExtractor) or not hasattr(f, "__name__"):
108
- fname = ""
109
- else:
110
- fname = f.__name__
208
+ fname = getattr(f_und, "__name__", "")
111
209
  if isinstance(r, dict):
112
- if fname:
113
- fname += "_"
210
+ prefix = f"{fname}_" if fname else ""
114
211
  for k, v in r.items():
115
- self._add_feature_to_dict(results_dict, fname + k, v, _batch_size)
212
+ self._add_feature_to_dict(results_dict, prefix + k, v, _batch_size)
116
213
  else:
117
214
  self._add_feature_to_dict(results_dict, fname, r, _batch_size)
118
215
  return results_dict
119
216
 
120
- def _add_feature_to_dict(self, results_dict, name, value, batch_size):
121
- if not isinstance(value, np.ndarray):
122
- results_dict[name] = value
123
- else:
217
+ def _add_feature_to_dict(
218
+ self, results_dict: dict, name: str, value: any, batch_size: int
219
+ ):
220
+ """Add a computed feature to the results dictionary."""
221
+ if isinstance(value, np.ndarray):
124
222
  assert value.shape[0] == batch_size
125
- results_dict[name] = value
223
+ results_dict[name] = value
126
224
 
127
225
  def clear(self):
226
+ """Clear the state of all trainable sub-features."""
128
227
  if not self._is_trainable:
129
228
  return
130
- for fname, f in self.feature_extractors_dict.items():
131
- f = _get_underlying_func(f)
132
- if isinstance(f, TrainableFeature):
133
- f.clear()
229
+ for f in self.feature_extractors_dict.values():
230
+ if isinstance(_get_underlying_func(f), TrainableFeature):
231
+ _get_underlying_func(f).clear()
134
232
 
135
233
  def partial_fit(self, *x, y=None):
234
+ """Partially fit all trainable sub-features."""
136
235
  if not self._is_trainable:
137
236
  return
138
237
  z = self.preprocess(*x, **self.preprocess_kwargs)
139
- for fname, f in self.feature_extractors_dict.items():
140
- f = _get_underlying_func(f)
141
- if isinstance(f, TrainableFeature):
142
- f.partial_fit(*z, y=y)
238
+ if not isinstance(z, tuple):
239
+ z = (z,)
240
+ for f in self.feature_extractors_dict.values():
241
+ if isinstance(_get_underlying_func(f), TrainableFeature):
242
+ _get_underlying_func(f).partial_fit(*z, y=y)
143
243
 
144
244
  def fit(self):
245
+ """Fit all trainable sub-features."""
145
246
  if not self._is_trainable:
146
247
  return
147
- for fname, f in self.feature_extractors_dict.items():
148
- f = _get_underlying_func(f)
149
- if isinstance(f, TrainableFeature):
248
+ for f in self.feature_extractors_dict.values():
249
+ if isinstance(_get_underlying_func(f), TrainableFeature):
150
250
  f.fit()
151
251
  super().fit()
152
252
 
153
253
 
154
254
  class MultivariateFeature:
155
- def __call__(self, x, _ch_names=None):
255
+ """A mixin for features that operate on multiple channels.
256
+
257
+ This class provides a `__call__` method that converts a feature array into
258
+ a dictionary with named features, where names are derived from channel
259
+ names.
260
+ """
261
+
262
+ def __call__(
263
+ self, x: np.ndarray, _ch_names: list[str] | None = None
264
+ ) -> dict | np.ndarray:
265
+ """Convert a feature array to a named dictionary.
266
+
267
+ Parameters
268
+ ----------
269
+ x : numpy.ndarray
270
+ The computed feature array.
271
+ _ch_names : list of str, optional
272
+ The list of channel names.
273
+
274
+ Returns
275
+ -------
276
+ dict or numpy.ndarray
277
+ A dictionary of named features, or the original array if feature
278
+ channel names cannot be generated.
279
+
280
+ """
156
281
  assert _ch_names is not None
157
282
  f_channels = self.feature_channel_names(_ch_names)
158
283
  if isinstance(x, dict):
@@ -163,37 +288,66 @@ class MultivariateFeature:
163
288
  return self._array_to_dict(x, f_channels)
164
289
 
165
290
  @staticmethod
166
- def _array_to_dict(x, f_channels, name=""):
291
+ def _array_to_dict(
292
+ x: np.ndarray, f_channels: list[str], name: str = ""
293
+ ) -> dict | np.ndarray:
294
+ """Convert a numpy array to a dictionary with named keys."""
167
295
  assert isinstance(x, np.ndarray)
168
- if len(f_channels) == 0:
169
- assert x.ndim == 1
170
- if name:
171
- return {name: x}
172
- return x
173
- assert x.shape[1] == len(f_channels)
296
+ if not f_channels:
297
+ return {name: x} if name else x
298
+ assert x.shape[1] == len(f_channels), f"{x.shape[1]} != {len(f_channels)}"
174
299
  x = x.swapaxes(0, 1)
175
- names = [f"{name}_{ch}" for ch in f_channels] if name else f_channels
300
+ prefix = f"{name}_" if name else ""
301
+ names = [f"{prefix}{ch}" for ch in f_channels]
176
302
  return dict(zip(names, x))
177
303
 
178
- def feature_channel_names(self, ch_names):
304
+ def feature_channel_names(self, ch_names: list[str]) -> list[str]:
305
+ """Generate feature names based on channel names.
306
+
307
+ Parameters
308
+ ----------
309
+ ch_names : list of str
310
+ The names of the input channels.
311
+
312
+ Returns
313
+ -------
314
+ list of str
315
+ The names for the output features.
316
+
317
+ """
179
318
  return []
180
319
 
181
320
 
182
321
  class UnivariateFeature(MultivariateFeature):
183
- def feature_channel_names(self, ch_names):
322
+ """A feature kind for operations applied to each channel independently."""
323
+
324
+ def feature_channel_names(self, ch_names: list[str]) -> list[str]:
325
+ """Return the channel names themselves as feature names."""
184
326
  return ch_names
185
327
 
186
328
 
187
329
  class BivariateFeature(MultivariateFeature):
188
- def __init__(self, *args, channel_pair_format="{}<>{}"):
330
+ """A feature kind for operations on pairs of channels.
331
+
332
+ Parameters
333
+ ----------
334
+ channel_pair_format : str, default="{}<>{}"
335
+ A format string used to create feature names from pairs of
336
+ channel names.
337
+
338
+ """
339
+
340
+ def __init__(self, *args, channel_pair_format: str = "{}<>{}"):
189
341
  super().__init__(*args)
190
342
  self.channel_pair_format = channel_pair_format
191
343
 
192
344
  @staticmethod
193
- def get_pair_iterators(n):
345
+ def get_pair_iterators(n: int) -> tuple[np.ndarray, np.ndarray]:
346
+ """Get indices for unique, unordered pairs of channels."""
194
347
  return np.triu_indices(n, 1)
195
348
 
196
- def feature_channel_names(self, ch_names):
349
+ def feature_channel_names(self, ch_names: list[str]) -> list[str]:
350
+ """Generate feature names for each pair of channels."""
197
351
  return [
198
352
  self.channel_pair_format.format(ch_names[i], ch_names[j])
199
353
  for i, j in zip(*self.get_pair_iterators(len(ch_names)))
@@ -201,8 +355,11 @@ class BivariateFeature(MultivariateFeature):
201
355
 
202
356
 
203
357
  class DirectedBivariateFeature(BivariateFeature):
358
+ """A feature kind for directed operations on pairs of channels."""
359
+
204
360
  @staticmethod
205
- def get_pair_iterators(n):
361
+ def get_pair_iterators(n: int) -> list[np.ndarray]:
362
+ """Get indices for all ordered pairs of channels (excluding self-pairs)."""
206
363
  return [
207
364
  np.append(a, b)
208
365
  for a, b in zip(np.tril_indices(n, -1), np.triu_indices(n, 1))
@@ -36,8 +36,12 @@ class EntropyFeatureExtractor(FeatureExtractor):
36
36
  counts_m = np.empty((*x.shape[:-1], (x.shape[-1] - m + 1) // l))
37
37
  counts_mp1 = np.empty((*x.shape[:-1], (x.shape[-1] - m) // l))
38
38
  for i in np.ndindex(x.shape[:-1]):
39
- counts_m[*i, :] = _channel_app_samp_entropy_counts(x[i], m, rr[i], l)
40
- counts_mp1[*i, :] = _channel_app_samp_entropy_counts(x[i], m + 1, rr[i], l)
39
+ counts_m[i + (slice(None),)] = _channel_app_samp_entropy_counts(
40
+ x[i], m, rr[i], l
41
+ )
42
+ counts_mp1[i + (slice(None),)] = _channel_app_samp_entropy_counts(
43
+ x[i], m + 1, rr[i], l
44
+ )
41
45
  return counts_m, counts_mp1
42
46
 
43
47
 
@@ -62,7 +66,7 @@ def complexity_sample_entropy(counts_m, counts_mp1):
62
66
  def complexity_svd_entropy(x, m=10, tau=1):
63
67
  x_emb = np.empty((*x.shape[:-1], (x.shape[-1] - m + 1) // tau, m))
64
68
  for i in np.ndindex(x.shape[:-1]):
65
- x_emb[*i, :, :] = _create_embedding(x[i], m, tau)
69
+ x_emb[i + (slice(None), slice(None))] = _create_embedding(x[i], m, tau)
66
70
  s = np.linalg.svdvals(x_emb)
67
71
  s /= s.sum(axis=-1, keepdims=True)
68
72
  return -np.sum(s * np.log(s), axis=-1)
@@ -26,7 +26,7 @@ def dimensionality_higuchi_fractal_dim(x, k_max=10, eps=1e-7):
26
26
  for i in np.ndindex(x.shape[:-1]):
27
27
  for k in range(1, k_max + 1):
28
28
  for m in range(k):
29
- L_km[m] = np.mean(np.abs(np.diff(x[*i, m:], n=k)))
29
+ L_km[m] = np.mean(np.abs(np.diff(x[i + (slice(m, None),)], n=k)))
30
30
  L_k[k - 1] = (N - 1) * np.sum(L_km[:k]) / (k**3)
31
31
  L_k = np.maximum(L_k, eps)
32
32
  hfd[i] = np.linalg.lstsq(log_k, np.log(L_k))[0][0]
@@ -8,20 +8,21 @@ from ..extractors import FeatureExtractor
8
8
 
9
9
  __all__ = [
10
10
  "HilbertFeatureExtractor",
11
- "signal_mean",
12
- "signal_variance",
13
- "signal_skewness",
11
+ "SIGNAL_PREDECESSORS",
12
+ "signal_decorrelation_time",
13
+ "signal_hjorth_activity",
14
+ "signal_hjorth_complexity",
15
+ "signal_hjorth_mobility",
14
16
  "signal_kurtosis",
15
- "signal_std",
16
- "signal_root_mean_square",
17
+ "signal_line_length",
18
+ "signal_mean",
17
19
  "signal_peak_to_peak",
18
20
  "signal_quantile",
21
+ "signal_root_mean_square",
22
+ "signal_skewness",
23
+ "signal_std",
24
+ "signal_variance",
19
25
  "signal_zero_crossings",
20
- "signal_line_length",
21
- "signal_hjorth_activity",
22
- "signal_hjorth_mobility",
23
- "signal_hjorth_complexity",
24
- "signal_decorrelation_time",
25
26
  ]
26
27
 
27
28
 
@@ -1,5 +1,13 @@
1
1
  import numpy as np
2
2
 
3
+ __all__ = [
4
+ "DEFAULT_FREQ_BANDS",
5
+ "get_valid_freq_band",
6
+ "reduce_freq_bands",
7
+ "slice_freq_band",
8
+ ]
9
+
10
+
3
11
  DEFAULT_FREQ_BANDS = {
4
12
  "delta": (1, 4.5),
5
13
  "theta": (4.5, 8),