eegdash 0.1.0__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


@@ -1,25 +1,53 @@
-# Features datasets
 from .datasets import FeaturesConcatDataset, FeaturesDataset
 from .decorators import (
     FeatureKind,
     FeaturePredecessor,
     bivariate_feature,
-    directed_bivariate_feature,
     multivariate_feature,
     univariate_feature,
 )
-
-# Feature extraction
 from .extractors import (
     BivariateFeature,
     DirectedBivariateFeature,
     FeatureExtractor,
-    FitableFeature,
     MultivariateFeature,
+    TrainableFeature,
     UnivariateFeature,
 )
-
-# Features:
-from .feature_bank import *
+from .feature_bank import *  # noqa: F401
+from .inspect import (
+    get_all_feature_extractors,
+    get_all_feature_kinds,
+    get_all_features,
+    get_feature_kind,
+    get_feature_predecessors,
+)
 from .serialization import load_features_concat_dataset
-from .utils import extract_features, fit_feature_extractors
+from .utils import (
+    extract_features,
+    fit_feature_extractors,
+)
+
+__all__ = [
+    "FeaturesConcatDataset",
+    "FeaturesDataset",
+    "FeatureKind",
+    "FeaturePredecessor",
+    "bivariate_feature",
+    "multivariate_feature",
+    "univariate_feature",
+    "BivariateFeature",
+    "DirectedBivariateFeature",
+    "FeatureExtractor",
+    "MultivariateFeature",
+    "TrainableFeature",
+    "UnivariateFeature",
+    "get_all_feature_extractors",
+    "get_all_feature_kinds",
+    "get_all_features",
+    "get_feature_kind",
+    "get_feature_predecessors",
+    "load_features_concat_dataset",
+    "extract_features",
+    "fit_feature_extractors",
+]
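The rewritten __init__ mostly re-exports the same names, renames FitableFeature to TrainableFeature, and adds the new inspect helpers to an explicit __all__. A minimal usage sketch, assuming the top-level re-exports behave as the new __all__ suggests (signal_abs_max below is a made-up feature, not part of the package):

    import numpy as np
    from eegdash.features import get_all_features, univariate_feature

    @univariate_feature
    def signal_abs_max(x):
        # toy custom feature over the trailing time axis
        return np.abs(x).max(axis=-1)

    # (name, callable) pairs shipped in feature_bank
    print([name for name, _ in get_all_features()])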
@@ -4,8 +4,8 @@ import json
 import os
 import shutil
 import warnings
-from collections.abc import Callable, Iterable
-from typing import Dict, no_type_check
+from collections.abc import Callable
+from typing import Dict, List

 import numpy as np
 import pandas as pd
@@ -33,6 +33,7 @@ class FeaturesDataset(EEGWindowsDataset):
         Holds additional description about the continuous signal / subject.
     transform : callable | None
         On-the-fly transform applied to the example before it is returned.
+
     """

     def __init__(
@@ -95,10 +96,12 @@ def _compute_stats(
     return tuple(res)


-def _pooled_var(counts, means, variances, ddof):
+def _pooled_var(counts, means, variances, ddof, ddof_in=None):
+    if ddof_in is None:
+        ddof_in = ddof
     count = counts.sum(axis=0)
     mean = np.sum((counts / count) * means, axis=0)
-    var = np.sum(((counts - ddof) / (count - ddof)) * variances, axis=0)
+    var = np.sum(((counts - ddof_in) / (count - ddof)) * variances, axis=0)
     var[:] += np.sum((counts / (count - ddof)) * (means**2), axis=0)
     var[:] -= (count / (count - ddof)) * (mean**2)
     var[:] = var.clip(min=0)
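The new ddof_in argument lets _pooled_var combine per-dataset variances that were computed with a different ddof (0.2.0 now gathers them with ddof=0, see the var()/zscore() hunks below) while still applying the requested ddof to the pooled result. A standalone sanity check of the pooled formula, illustrative only and using scalar features:

    import numpy as np

    rng = np.random.default_rng(0)
    chunks = [rng.normal(size=n) for n in (50, 80, 120)]

    counts = np.array([len(c) for c in chunks], dtype=float)
    means = np.array([c.mean() for c in chunks])
    variances = np.array([c.var(ddof=0) for c in chunks])  # per-chunk stats, i.e. ddof_in = 0

    ddof = 1
    n = counts.sum()
    mean = np.sum((counts / n) * means)
    var = np.sum(((counts - 0) / (n - ddof)) * variances)   # the (counts - ddof_in) term
    var += np.sum((counts / (n - ddof)) * means**2)
    var -= (n / (n - ddof)) * mean**2

    # pooled result matches the variance of the concatenated data
    assert np.isclose(var, np.concatenate(chunks).var(ddof=ddof))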
@@ -156,6 +159,7 @@ class FeaturesConcatDataset(BaseConcatDataset):
         splits : dict
             A dictionary with the name of the split (a string) as key and the
             dataset as value.
+
         """
         if isinstance(by, str):
             split_ids = {
@@ -187,6 +191,7 @@ class FeaturesConcatDataset(BaseConcatDataset):
             DataFrame containing as many rows as there are windows in the
             BaseConcatDataset, with the metadata and description information
             for each window.
+
         """
         if not all([isinstance(ds, FeaturesDataset) for ds in self.datasets]):
             raise TypeError(
@@ -238,6 +243,7 @@ class FeaturesConcatDataset(BaseConcatDataset):
             concat. This is useful in the setting of very large datasets, where
             one dataset has to be processed and saved at a time to account for
             its original position.
+
         """
         if len(self.datasets) == 0:
             raise ValueError("Expect at least one dataset")
@@ -323,25 +329,53 @@ class FeaturesConcatDataset(BaseConcatDataset):
             json.dump(kwargs, f)

     def to_dataframe(
-        self, include_metadata=False, include_target=False, include_crop_inds=False
+        self,
+        include_metadata: bool | str | List[str] = False,
+        include_target: bool = False,
+        include_crop_inds: bool = False,
     ):
-        if include_metadata or (include_target and include_crop_inds):
+        if (
+            not isinstance(include_metadata, bool)
+            or include_metadata
+            or include_crop_inds
+        ):
+            include_dataset = False
+            if isinstance(include_metadata, bool) and include_metadata:
+                include_dataset = True
+                cols = self.datasets[0].metadata.columns
+            else:
+                cols = include_metadata
+                if isinstance(cols, bool) and not cols:
+                    cols = []
+                elif isinstance(cols, str):
+                    cols = [cols]
+                cols = set(cols)
+                if include_crop_inds:
+                    cols = {
+                        "i_dataset",
+                        "i_window_in_trial",
+                        "i_start_in_trial",
+                        "i_stop_in_trial",
+                        *cols,
+                    }
+                if include_target:
+                    cols.add("target")
+                cols = list(cols)
+                include_dataset = "i_dataset" in cols
+                if include_dataset:
+                    cols.remove("i_dataset")
             dataframes = [
-                ds.metadata.join(ds.features, how="right", lsuffix="_metadata")
+                ds.metadata[cols].join(ds.features, how="right", lsuffix="_metadata")
                 for ds in self.datasets
             ]
+            if include_dataset:
+                for i, df in enumerate(dataframes):
+                    df.insert(loc=0, column="i_dataset", value=i)
         elif include_target:
             dataframes = [
                 ds.features.join(ds.metadata["target"], how="left", rsuffix="_metadata")
                 for ds in self.datasets
             ]
-        elif include_crop_inds:
-            dataframes = [
-                ds.metadata.drop("target", axis="columns").join(
-                    ds.features, how="right", lsuffix="_metadata"
-                )
-                for ds in self.datasets
-            ]
         else:
             dataframes = [ds.features for ds in self.datasets]
         return pd.concat(dataframes, axis=0, ignore_index=True)
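to_dataframe now accepts a metadata column selection instead of the old all-or-nothing flags. Hypothetical call patterns (concat_ds stands for a FeaturesConcatDataset; the "subject" column is an assumption about the metadata):

    df = concat_ds.to_dataframe()                                        # features only
    df = concat_ds.to_dataframe(include_metadata=True)                   # all metadata columns
    df = concat_ds.to_dataframe(include_metadata=["subject", "target"])  # selected columns only
    df = concat_ds.to_dataframe(include_crop_inds=True)                  # adds i_dataset and i_*_in_trial columns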
@@ -377,7 +411,7 @@ class FeaturesConcatDataset(BaseConcatDataset):
                 return_count=True,
                 return_mean=True,
                 return_var=True,
-                ddof=ddof,
+                ddof=0,
                 numeric_only=numeric_only,
             )
             for ds in self.datasets
@@ -387,11 +421,13 @@ class FeaturesConcatDataset(BaseConcatDataset):
             np.array([s[1] for s in stats]),
             np.array([s[2] for s in stats]),
         )
-        _, _, var = _pooled_var(counts, means, variances, ddof)
+        _, _, var = _pooled_var(counts, means, variances, ddof, ddof_in=0)
         return pd.Series(var, index=self._numeric_columns())

-    def std(self, ddof=1, numeric_only=False, n_jobs=1):
-        return np.sqrt(self.var(ddof=ddof, numeric_only=numeric_only, n_jobs=n_jobs))
+    def std(self, ddof=1, numeric_only=False, eps=0, n_jobs=1):
+        return np.sqrt(
+            self.var(ddof=ddof, numeric_only=numeric_only, n_jobs=n_jobs) + eps
+        )

     def zscore(self, ddof=1, numeric_only=False, eps=0, n_jobs=1):
         stats = Parallel(n_jobs)(
@@ -400,7 +436,7 @@ class FeaturesConcatDataset(BaseConcatDataset):
                 return_count=True,
                 return_mean=True,
                 return_var=True,
-                ddof=ddof,
+                ddof=0,
                 numeric_only=numeric_only,
             )
             for ds in self.datasets
@@ -410,8 +446,8 @@ class FeaturesConcatDataset(BaseConcatDataset):
             np.array([s[1] for s in stats]),
             np.array([s[2] for s in stats]),
         )
-        _, mean, var = _pooled_var(counts, means, variances, ddof)
-        std = np.sqrt(var) + eps
+        _, mean, var = _pooled_var(counts, means, variances, ddof, ddof_in=0)
+        std = np.sqrt(var + eps)
         for ds in self.datasets:
             ds.features = (ds.features - mean) / std

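Two related behaviour changes here: per-dataset statistics are now always gathered with ddof=0 and the requested ddof is applied only when pooling (via ddof_in=0), and the eps stabiliser moved inside the square root, i.e. sqrt(var + eps) instead of sqrt(var) + eps in zscore (std gains the same optional eps). A quick numeric illustration of the difference, not library code:

    import numpy as np

    var, eps = np.array([0.0, 1e-12]), 1e-8
    print(np.sqrt(var) + eps)   # old zscore denominator: roughly [1e-08, 1.01e-06]
    print(np.sqrt(var + eps))   # new behaviour: roughly [1e-04, 1.00005e-04]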
@@ -38,6 +38,14 @@ class FeatureKind:

 # Syntax sugar
 univariate_feature = FeatureKind(UnivariateFeature())
-bivariate_feature = FeatureKind(BivariateFeature())
-directed_bivariate_feature = FeatureKind(DirectedBivariateFeature())
+
+
+def bivariate_feature(func, directed=False):
+    if directed:
+        kind = DirectedBivariateFeature()
+    else:
+        kind = BivariateFeature()
+    return FeatureKind(kind)(func)
+
+
 multivariate_feature = FeatureKind(MultivariateFeature())
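bivariate_feature changes from a FeatureKind instance into a small factory, with directed=True replacing the removed directed_bivariate_feature. A decoration sketch based only on the new signature (the placeholder body is not a meaningful connectivity measure):

    from eegdash.features import bivariate_feature

    def _placeholder(x):
        return x.mean(axis=-1)

    undirected = bivariate_feature(_placeholder)               # same effect as @bivariate_feature
    directed = bivariate_feature(_placeholder, directed=True)  # replaces @directed_bivariate_feature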
@@ -16,9 +16,9 @@ def _get_underlying_func(func):
     return f


-class FitableFeature(ABC):
+class TrainableFeature(ABC):
     def __init__(self):
-        self._is_fitted = False
+        self._is_trained = False
         self.clear()

     @abstractmethod
@@ -39,12 +39,12 @@ class FitableFeature(ABC):
         )


-class FeatureExtractor(FitableFeature):
+class FeatureExtractor(TrainableFeature):
     def __init__(
         self, feature_extractors: Dict[str, Callable], **preprocess_kwargs: Dict
     ):
         self.feature_extractors_dict = self._validate_execution_tree(feature_extractors)
-        self._is_fitable = self._check_is_fitable(feature_extractors)
+        self._is_trainable = self._check_is_trainable(feature_extractors)
         super().__init__()

         # bypassing FeaturePredecessor to avoid circular import
@@ -70,32 +70,31 @@ class FeatureExtractor(FitableFeature):
             assert type(self) in pe_type
         return feature_extractors

-    def _check_is_fitable(self, feature_extractors):
-        is_fitable = False
+    def _check_is_trainable(self, feature_extractors):
+        is_trainable = False
         for fname, f in feature_extractors.items():
             if isinstance(f, FeatureExtractor):
-                is_fitable = f._is_fitable
+                is_trainable = f._is_trainable
             else:
                 f = _get_underlying_func(f)
-                if isinstance(f, FitableFeature):
-                    is_fitable = True
-            if is_fitable:
+                if isinstance(f, TrainableFeature):
+                    is_trainable = True
+            if is_trainable:
                 break
-        return is_fitable
+        return is_trainable

     def preprocess(self, *x, **kwargs):
         return (*x,)

-    def feature_channel_names(self, ch_names):
-        return [""]
-
     def __call__(self, *x, _batch_size=None, _ch_names=None):
         assert _batch_size is not None
         assert _ch_names is not None
-        if self._is_fitable:
+        if self._is_trainable:
             super().__call__()
         results_dict = dict()
         z = self.preprocess(*x, **self.preprocess_kwargs)
+        if not isinstance(z, tuple):
+            z = (z,)
         for fname, f in self.feature_extractors_dict.items():
             if isinstance(f, FeatureExtractor):
                 r = f(*z, _batch_size=_batch_size, _ch_names=_ch_names)
@@ -126,28 +125,28 @@ class FeatureExtractor(FitableFeature):
             results_dict[name] = value

     def clear(self):
-        if not self._is_fitable:
+        if not self._is_trainable:
             return
         for fname, f in self.feature_extractors_dict.items():
             f = _get_underlying_func(f)
-            if isinstance(f, FitableFeature):
+            if isinstance(f, TrainableFeature):
                 f.clear()

     def partial_fit(self, *x, y=None):
-        if not self._is_fitable:
+        if not self._is_trainable:
             return
         z = self.preprocess(*x, **self.preprocess_kwargs)
         for fname, f in self.feature_extractors_dict.items():
             f = _get_underlying_func(f)
-            if isinstance(f, FitableFeature):
+            if isinstance(f, TrainableFeature):
                 f.partial_fit(*z, y=y)

     def fit(self):
-        if not self._is_fitable:
+        if not self._is_trainable:
             return
         for fname, f in self.feature_extractors_dict.items():
             f = _get_underlying_func(f)
-            if isinstance(f, FitableFeature):
+            if isinstance(f, TrainableFeature):
                 f.fit()
         super().fit()

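The FitableFeature → TrainableFeature rename runs through the whole extractor API (_is_fitable → _is_trainable, _check_is_fitable → _check_is_trainable), and __call__ now tolerates preprocess() returning a single array instead of a tuple. A rough sketch of what a trainable feature might look like under the renamed interface, assuming only the methods visible in this diff (clear, partial_fit, fit); attribute names are illustrative:

    from eegdash.features import TrainableFeature

    class RunningMeanOffset(TrainableFeature):
        # would be registered with a feature-kind decorator, like CommonSpatialPattern below
        def clear(self):
            self._sum, self._count = 0.0, 0

        def partial_fit(self, x, y=None):
            self._sum += float(x.sum())
            self._count += x.size

        def fit(self):
            self._mean = self._sum / max(self._count, 1)
            super().fit()

        def __call__(self, x):
            return x.mean(axis=-1) - self._mean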
@@ -4,6 +4,7 @@ from sklearn.neighbors import KDTree

 from ..decorators import FeaturePredecessor, univariate_feature
 from ..extractors import FeatureExtractor
+from .signal import SIGNAL_PREDECESSORS

 __all__ = [
     "EntropyFeatureExtractor",
@@ -28,6 +29,7 @@ def _channel_app_samp_entropy_counts(x, m, r, l):
     return kdtree.query_radius(x_emb, r, count_only=True)


+@FeaturePredecessor(*SIGNAL_PREDECESSORS)
 class EntropyFeatureExtractor(FeatureExtractor):
     def preprocess(self, x, m=2, r=0.2, l=1):
         rr = r * x.std(axis=-1)
@@ -55,6 +57,7 @@ def complexity_sample_entropy(counts_m, counts_mp1):
     return -np.log(A / B)


+@FeaturePredecessor(*SIGNAL_PREDECESSORS)
 @univariate_feature
 def complexity_svd_entropy(x, m=10, tau=1):
     x_emb = np.empty((*x.shape[:-1], (x.shape[-1] - m + 1) // tau, m))
@@ -65,6 +68,7 @@ def complexity_svd_entropy(x, m=10, tau=1):
     return -np.sum(s * np.log(s), axis=-1)


+@FeaturePredecessor(*SIGNAL_PREDECESSORS)
 @univariate_feature
 @nb.njit(cache=True, fastmath=True)
 def complexity_lempel_ziv(x, threshold=None):
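For reference, the A/B ratio consumed by complexity_sample_entropy counts template matches of length m+1 and m. A simplified, self-contained version of sample entropy for a 1-D signal (illustrative; the package's KDTree-based implementation above is vectorised and normalised differently):

    import numpy as np

    def sample_entropy(x, m=2, r=0.2):
        r = r * x.std()
        def count_pairs(k):
            emb = np.lib.stride_tricks.sliding_window_view(x, k)
            d = np.max(np.abs(emb[:, None, :] - emb[None, :, :]), axis=-1)  # Chebyshev distance
            return np.sum(d <= r) - len(emb)   # drop self-matches on the diagonal
        return -np.log(count_pairs(m + 1) / count_pairs(m))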
@@ -4,7 +4,7 @@ import scipy
 import scipy.linalg

 from ..decorators import multivariate_feature
-from ..extractors import FitableFeature
+from ..extractors import TrainableFeature

 __all__ = [
     "CommonSpatialPattern",
@@ -22,7 +22,7 @@ def _update_mean_cov(count, mean, cov, x_count, x_mean, x_cov):


 @multivariate_feature
-class CommonSpatialPattern(FitableFeature):
+class CommonSpatialPattern(TrainableFeature):
     def __init__(self):
         super().__init__()

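CommonSpatialPattern keeps its role as a trainable multivariate feature; only the base-class name changes. For context, the core computation such a feature typically performs at fit() time is a generalized eigendecomposition of the per-condition covariances (illustrative sketch; the class's actual attribute names are not visible in this diff):

    import numpy as np
    from scipy import linalg

    def csp_filters(cov_a, cov_b):
        # maximise the variance ratio between the two conditions
        evals, evecs = linalg.eigh(cov_a, cov_a + cov_b)
        return evecs[:, np.argsort(evals)[::-1]]   # most discriminative spatial filters first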
@@ -2,8 +2,8 @@ import numba as nb
 import numpy as np
 from scipy import special

-from ..decorators import univariate_feature
-from .signal import signal_zero_crossings
+from ..decorators import FeaturePredecessor, univariate_feature
+from .signal import SIGNAL_PREDECESSORS, signal_zero_crossings

 __all__ = [
     "dimensionality_higuchi_fractal_dim",
@@ -14,6 +14,7 @@ __all__ = [
 ]


+@FeaturePredecessor(*SIGNAL_PREDECESSORS)
 @univariate_feature
 @nb.njit(cache=True, fastmath=True)
 def dimensionality_higuchi_fractal_dim(x, k_max=10, eps=1e-7):
@@ -32,6 +33,7 @@ def dimensionality_higuchi_fractal_dim(x, k_max=10, eps=1e-7):
     return hfd


+@FeaturePredecessor(*SIGNAL_PREDECESSORS)
 @univariate_feature
 def dimensionality_petrosian_fractal_dim(x):
     nd = signal_zero_crossings(np.diff(x, axis=-1))
@@ -39,6 +41,7 @@ def dimensionality_petrosian_fractal_dim(x):
     return log_n / (np.log(nd) + log_n)


+@FeaturePredecessor(*SIGNAL_PREDECESSORS)
 @univariate_feature
 def dimensionality_katz_fractal_dim(x):
     dists = np.abs(np.diff(x, axis=-1))
@@ -49,7 +52,6 @@ def dimensionality_katz_fractal_dim(x):
     return log_n / (np.log(d / L) + log_n)


-@univariate_feature
 @nb.njit(cache=True, fastmath=True)
 def _hurst_exp(x, ns, a, gamma_ratios, log_n):
     h = np.empty(x.shape[:-1])
@@ -75,6 +77,7 @@ def _hurst_exp(x, ns, a, gamma_ratios, log_n):
     return h


+@FeaturePredecessor(*SIGNAL_PREDECESSORS)
 @univariate_feature
 def dimensionality_hurst_exp(x):
     ns = np.unique(np.power(2, np.arange(2, np.log2(x.shape[-1]) - 1)).astype(int))
@@ -88,6 +91,7 @@ def dimensionality_hurst_exp(x):
     return _hurst_exp(x, ns, a, gamma_ratios, log_n)


+@FeaturePredecessor(*SIGNAL_PREDECESSORS)
 @univariate_feature
 @nb.njit(cache=True, fastmath=True)
 def dimensionality_detrended_fluctuation_analysis(x):
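All fractal-dimension features are now registered with the SIGNAL_PREDECESSORS list (defined in the signal module below), so they can run on the raw signal or on its Hilbert envelope. For orientation, the textbook form of the Petrosian fractal dimension computed above is (constants may differ from the package's partially shown implementation):

    import numpy as np

    def petrosian_fd(x):
        n = x.shape[-1]
        dx = np.diff(x, axis=-1)
        n_delta = np.sum(dx[..., 1:] * dx[..., :-1] < 0, axis=-1)   # sign changes of the derivative
        return np.log10(n) / (np.log10(n) + np.log10(n / (n + 0.4 * n_delta)))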
@@ -1,11 +1,13 @@
 import numbers

 import numpy as np
-from scipy import stats
+from scipy import signal, stats

-from ..decorators import univariate_feature
+from ..decorators import FeaturePredecessor, univariate_feature
+from ..extractors import FeatureExtractor

 __all__ = [
+    "HilbertFeatureExtractor",
     "signal_mean",
     "signal_variance",
     "signal_skewness",
@@ -23,51 +25,70 @@ __all__ = [
 ]


+@FeaturePredecessor(FeatureExtractor)
+class HilbertFeatureExtractor(FeatureExtractor):
+    def preprocess(self, x):
+        return np.abs(signal.hilbert(x - x.mean(axis=-1, keepdims=True), axis=-1))
+
+
+SIGNAL_PREDECESSORS = [FeatureExtractor, HilbertFeatureExtractor]
+
+
+@FeaturePredecessor(*SIGNAL_PREDECESSORS)
 @univariate_feature
 def signal_mean(x):
     return x.mean(axis=-1)


+@FeaturePredecessor(*SIGNAL_PREDECESSORS)
 @univariate_feature
 def signal_variance(x, **kwargs):
     return x.var(axis=-1, **kwargs)


+@FeaturePredecessor(*SIGNAL_PREDECESSORS)
 @univariate_feature
 def signal_std(x, **kwargs):
     return x.std(axis=-1, **kwargs)


+@FeaturePredecessor(*SIGNAL_PREDECESSORS)
 @univariate_feature
 def signal_skewness(x, **kwargs):
     return stats.skew(x, axis=x.ndim - 1, **kwargs)


+@FeaturePredecessor(*SIGNAL_PREDECESSORS)
 @univariate_feature
 def signal_kurtosis(x, **kwargs):
     return stats.kurtosis(x, axis=x.ndim - 1, **kwargs)


+@FeaturePredecessor(*SIGNAL_PREDECESSORS)
 @univariate_feature
 def signal_root_mean_square(x):
     return np.sqrt(np.power(x, 2).mean(axis=-1))


+@FeaturePredecessor(*SIGNAL_PREDECESSORS)
 @univariate_feature
 def signal_peak_to_peak(x, **kwargs):
     return np.ptp(x, axis=-1, **kwargs)


+@FeaturePredecessor(*SIGNAL_PREDECESSORS)
 @univariate_feature
 def signal_quantile(x, q: numbers.Number = 0.5, **kwargs):
     return np.quantile(x, q=q, axis=-1, **kwargs)


+@FeaturePredecessor(*SIGNAL_PREDECESSORS)
 @univariate_feature
 def signal_line_length(x):
     return np.abs(np.diff(x, axis=-1)).mean(axis=-1)


+@FeaturePredecessor(*SIGNAL_PREDECESSORS)
 @univariate_feature
 def signal_zero_crossings(x, threshold=1e-15):
     zero_ind = np.logical_and(x > -threshold, x < threshold)
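HilbertFeatureExtractor and SIGNAL_PREDECESSORS are the pieces that the many @FeaturePredecessor(*SIGNAL_PREDECESSORS) registrations rely on: each signal feature may now sit either directly under a FeatureExtractor (raw signal) or under a HilbertFeatureExtractor (amplitude envelope). A nesting sketch based on the dict-of-callables constructor shown earlier in this diff (the chosen feature names are just examples):

    from eegdash.features import FeatureExtractor
    from eegdash.features.feature_bank import (
        HilbertFeatureExtractor,
        signal_line_length,
        signal_variance,
    )

    fx = FeatureExtractor({
        "var": signal_variance,             # computed on the raw windows
        "env": HilbertFeatureExtractor({    # computed on the Hilbert envelope
            "var": signal_variance,
            "ll": signal_line_length,
        }),
    })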
@@ -78,16 +99,21 @@ def signal_zero_crossings(x, threshold=1e-15):
     return zero_cross


+@FeaturePredecessor(*SIGNAL_PREDECESSORS)
 @univariate_feature
 def signal_hjorth_mobility(x):
     return np.diff(x, axis=-1).std(axis=-1) / x.std(axis=-1)


+@FeaturePredecessor(*SIGNAL_PREDECESSORS)
 @univariate_feature
 def signal_hjorth_complexity(x):
-    return np.diff(x, 2, axis=-1).std(axis=-1) / x.std(axis=-1)
+    return (np.diff(x, 2, axis=-1).std(axis=-1) * x.std(axis=-1)) / np.diff(
+        x, axis=-1
+    ).var(axis=-1)


+@FeaturePredecessor(*SIGNAL_PREDECESSORS)
 @univariate_feature
 def signal_decorrelation_time(x, fs=1):
     f = np.fft.fft(x - x.mean(axis=-1, keepdims=True), axis=-1)
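The signal_hjorth_complexity change is a genuine bug fix: Hjorth complexity is mobility(dx) / mobility(x), where mobility(s) = std(diff(s)) / std(s); expanding that ratio gives exactly the new expression, whereas 0.1.0 returned std(ddx) / std(x). A quick check that the new expression matches the definition (illustrative only):

    import numpy as np

    x = np.random.randn(4, 1000)
    dx, ddx = np.diff(x, axis=-1), np.diff(x, 2, axis=-1)
    mobility = lambda s: np.diff(s, axis=-1).std(axis=-1) / s.std(axis=-1)

    new = (ddx.std(axis=-1) * x.std(axis=-1)) / dx.var(axis=-1)
    assert np.allclose(new, mobility(dx) / mobility(x))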
@@ -0,0 +1,48 @@
+import inspect
+from collections.abc import Callable
+
+from . import extractors, feature_bank
+from .extractors import FeatureExtractor, MultivariateFeature, _get_underlying_func
+
+
+def get_feature_predecessors(feature_or_extractor: Callable):
+    current = _get_underlying_func(feature_or_extractor)
+    if current is FeatureExtractor:
+        return [current]
+    predecessor = getattr(current, "parent_extractor_type", [FeatureExtractor])
+    if len(predecessor) == 1:
+        return [current, *get_feature_predecessors(predecessor[0])]
+    else:
+        predecessors = [get_feature_predecessors(pred) for pred in predecessor]
+        for i in range(len(predecessors)):
+            if isinstance(predecessors[i], list) and len(predecessors[i]) == 1:
+                predecessors[i] = predecessors[i][0]
+        return [current, tuple(predecessors)]
+
+
+def get_feature_kind(feature: Callable):
+    return _get_underlying_func(feature).feature_kind
+
+
+def get_all_features():
+    def isfeature(x):
+        return hasattr(_get_underlying_func(x), "feature_kind")
+
+    return inspect.getmembers(feature_bank, isfeature)
+
+
+def get_all_feature_extractors():
+    def isfeatureextractor(x):
+        return inspect.isclass(x) and issubclass(x, FeatureExtractor)
+
+    return [
+        ("FeatureExtractor", FeatureExtractor),
+        *inspect.getmembers(feature_bank, isfeatureextractor),
+    ]
+
+
+def get_all_feature_kinds():
+    def isfeaturekind(x):
+        return inspect.isclass(x) and issubclass(x, MultivariateFeature)
+
+    return inspect.getmembers(extractors, isfeaturekind)
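This new inspect module backs the helpers added to __all__ at the top of this diff. Hypothetical usage, with the output shapes inferred from the code above:

    from eegdash.features import get_feature_kind, get_feature_predecessors
    from eegdash.features.feature_bank import signal_mean

    print(get_feature_kind(signal_mean))          # the feature-kind instance attached by the decorator
    print(get_feature_predecessors(signal_mean))  # the feature followed by its predecessor chain(s)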
@@ -1,10 +1,8 @@
-"""
-Convenience functions for storing and loading of features datasets.
+"""Convenience functions for storing and loading of features datasets.

 see also: https://github.com/braindecode/braindecode//blob/master/braindecode/datautil/serialization.py#L165-L229
 """

-import json
 from pathlib import Path

 import pandas as pd
@@ -34,6 +32,7 @@ def load_features_concat_dataset(path, ids_to_load=None, n_jobs=1):
     Returns
     -------
     concat_dataset: FeaturesConcatDataset of FeaturesDatasets
+
     """
     # Make sure we always work with a pathlib.Path
     path = Path(path)
eegdash/features/utils.py CHANGED
@@ -102,7 +102,7 @@ def fit_feature_extractors(
         features = dict(enumerate(features))
     if not isinstance(features, FeatureExtractor):
         features = FeatureExtractor(features)
-    if not features._is_fitable:
+    if not features._is_trainable:
         return features
     features.clear()
     concat_dl = DataLoader(
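Only the trainability flag check changes here, but it is the switch that decides whether fit_feature_extractors actually iterates over the data. A hedged end-to-end sketch of the intended workflow (function names come from this diff; windows_ds, the argument order, and the return types are assumptions):

    from eegdash.features import FeatureExtractor, extract_features, fit_feature_extractors
    from eegdash.features.feature_bank import signal_mean, signal_variance

    fx = FeatureExtractor({"mean": signal_mean, "var": signal_variance})
    fx = fit_feature_extractors(windows_ds, fx)   # returns immediately: nothing here is trainable
    features = extract_features(windows_ds, fx)   # per-window feature datasets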