eegdash 0.0.9__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


eegdash/dataset.py ADDED
@@ -0,0 +1,60 @@
+from .api import EEGDashDataset
+
+
+class EEGChallengeDataset(EEGDashDataset):
+    def __init__(
+        self,
+        release: str = "R5",
+        cache_dir: str = ".eegdash_cache",
+        s3_bucket: str | None = "s3://nmdatasets/NeurIPS25/R5_L100",
+        **kwargs,
+    ):
+        """Create a new EEGDashDataset from a given query or local BIDS dataset directory
+        and dataset name. An EEGDashDataset is a pooled collection of EEGDashBaseDataset
+        instances (individual recordings) and is a subclass of braindecode's BaseConcatDataset.
+
+        Parameters
+        ----------
+        query : dict | None
+            Optionally a dictionary that specifies the query to be executed; see
+            EEGDash.find() for details on the query format.
+        data_dir : str | list[str] | None
+            Optionally a string or a list of strings specifying one or more local
+            BIDS dataset directories from which to load the EEG data files. Exactly one
+            of query or data_dir must be provided.
+        dataset : str | list[str] | None
+            If data_dir is given, a name or list of names for the dataset(s) to be loaded.
+        description_fields : list[str]
+            A list of fields to be extracted from the dataset records
+            and included in the returned data description(s). Examples are typical
+            subject metadata fields such as "subject", "session", "run", "task", etc.;
+            see also data_config.description_fields for the default set of fields.
+        cache_dir : str
+            A directory where the dataset will be cached locally.
+        s3_bucket : str | None
+            An optional S3 bucket URI to use instead of the
+            default OpenNeuro bucket for loading data files.
+        kwargs : dict
+            Additional keyword arguments to be passed to the EEGDashBaseDataset
+            constructor.
+
+        """
+        dsnumber_release_map = {
+            "R11": "ds005516",
+            "R10": "ds005515",
+            "R9": "ds005514",
+            "R8": "ds005512",
+            "R7": "ds005511",
+            "R6": "ds005510",
+            "R4": "ds005508",
+            "R5": "ds005509",
+            "R3": "ds005507",
+            "R2": "ds005506",
+            "R1": "ds005505",
+        }
+        super().__init__(
+            query={"dataset": dsnumber_release_map[release]},
+            cache_dir=cache_dir,
+            s3_bucket=s3_bucket,
+            **kwargs,
+        )
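For orientation, a minimal usage sketch of the new class (not part of the released file): it assumes only what the diff shows, namely that EEGChallengeDataset lives in eegdash.dataset and maps a release tag to its OpenNeuro dataset ID before delegating to EEGDashDataset.

    # Hypothetical usage: load one EEG challenge release by its release tag.
    from eegdash.dataset import EEGChallengeDataset

    dataset = EEGChallengeDataset(
        release="R5",                # any of "R1" ... "R11"; resolves to "ds005509"
        cache_dir=".eegdash_cache",  # local cache directory for downloaded files
    )
    # As a braindecode BaseConcatDataset subclass, it exposes a per-recording
    # description table.
    print(dataset.description.head())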
@@ -1,25 +1,53 @@
-# Features datasets
-from .datasets import FeaturesDataset, FeaturesConcatDataset
-from .serialization import load_features_concat_dataset
-
-# Feature extraction
+from .datasets import FeaturesConcatDataset, FeaturesDataset
+from .decorators import (
+    FeatureKind,
+    FeaturePredecessor,
+    bivariate_feature,
+    multivariate_feature,
+    univariate_feature,
+)
 from .extractors import (
-    FeatureExtractor,
-    FitableFeature,
-    UnivariateFeature,
     BivariateFeature,
     DirectedBivariateFeature,
+    FeatureExtractor,
     MultivariateFeature,
+    TrainableFeature,
+    UnivariateFeature,
 )
-from .decorators import (
-    FeaturePredecessor,
-    FeatureKind,
-    univariate_feature,
-    bivariate_feature,
-    directed_bivariate_feature,
-    multivariate_feature,
+from .feature_bank import *  # noqa: F401
+from .inspect import (
+    get_all_feature_extractors,
+    get_all_feature_kinds,
+    get_all_features,
+    get_feature_kind,
+    get_feature_predecessors,
+)
+from .serialization import load_features_concat_dataset
+from .utils import (
+    extract_features,
+    fit_feature_extractors,
 )
-from .utils import extract_features, fit_feature_extractors
 
-# Features:
-from .feature_bank import *
+__all__ = [
+    "FeaturesConcatDataset",
+    "FeaturesDataset",
+    "FeatureKind",
+    "FeaturePredecessor",
+    "bivariate_feature",
+    "multivariate_feature",
+    "univariate_feature",
+    "BivariateFeature",
+    "DirectedBivariateFeature",
+    "FeatureExtractor",
+    "MultivariateFeature",
+    "TrainableFeature",
+    "UnivariateFeature",
+    "get_all_feature_extractors",
+    "get_all_feature_kinds",
+    "get_all_features",
+    "get_feature_kind",
+    "get_feature_predecessors",
+    "load_features_concat_dataset",
+    "extract_features",
+    "fit_feature_extractors",
+]
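The reorganized __init__ now declares an explicit public namespace, including the new inspect helpers. A hedged sketch of how that namespace might be explored, assuming the subpackage is importable as eegdash.features (the helpers' return types are not visible in this diff):

    from eegdash import features

    # Everything re-exported by the package, as declared in __all__ above.
    print(sorted(features.__all__))

    # The new inspect helpers advertise the available extractors, kinds and
    # features; their exact return types are not shown in this diff.
    extractors = features.get_all_feature_extractors()
    kinds = features.get_all_feature_kinds()
    everything = features.get_all_features()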
@@ -1,16 +1,19 @@
 from __future__ import annotations
-import os
+
 import json
+import os
 import shutil
 import warnings
-from typing import Dict, no_type_check
-from collections.abc import Callable, Iterable
+from collections.abc import Callable
+from typing import Dict, List
+
 import numpy as np
 import pandas as pd
 from joblib import Parallel, delayed
+
 from braindecode.datasets.base import (
-    EEGWindowsDataset,
     BaseConcatDataset,
+    EEGWindowsDataset,
     _create_description,
 )
 
@@ -30,6 +33,7 @@ class FeaturesDataset(EEGWindowsDataset):
         Holds additional description about the continuous signal / subject.
     transform : callable | None
         On-the-fly transform applied to the example before it is returned.
+
     """
 
     def __init__(
@@ -92,10 +96,12 @@ def _compute_stats(
     return tuple(res)
 
 
-def _pooled_var(counts, means, variances, ddof):
+def _pooled_var(counts, means, variances, ddof, ddof_in=None):
+    if ddof_in is None:
+        ddof_in = ddof
     count = counts.sum(axis=0)
     mean = np.sum((counts / count) * means, axis=0)
-    var = np.sum(((counts - ddof) / (count - ddof)) * variances, axis=0)
+    var = np.sum(((counts - ddof_in) / (count - ddof)) * variances, axis=0)
     var[:] += np.sum((counts / (count - ddof)) * (means**2), axis=0)
     var[:] -= (count / (count - ddof)) * (mean**2)
     var[:] = var.clip(min=0)
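The new ddof_in argument decouples the ddof used for the per-dataset variances from the ddof applied to the pooled result; the var and zscore callers later in this diff pass ddof_in=0, i.e. population variances per dataset pooled into a ddof-corrected total. A standalone NumPy check of that algebra (independent of the package):

    import numpy as np

    rng = np.random.default_rng(0)
    parts = [rng.normal(size=(n, 3)) for n in (50, 80, 120)]  # three "datasets"

    counts = np.array([[p.shape[0]] * p.shape[1] for p in parts], dtype=float)
    means = np.array([p.mean(axis=0) for p in parts])
    variances = np.array([p.var(axis=0, ddof=0) for p in parts])  # ddof_in = 0

    ddof = 1
    count = counts.sum(axis=0)
    mean = np.sum((counts / count) * means, axis=0)
    # Same algebra as _pooled_var with ddof_in=0: (counts - 0) / (count - ddof)
    var = np.sum((counts / (count - ddof)) * variances, axis=0)
    var += np.sum((counts / (count - ddof)) * means**2, axis=0)
    var -= (count / (count - ddof)) * mean**2

    # Matches the ddof=1 variance of the concatenated data.
    np.testing.assert_allclose(var, np.concatenate(parts).var(axis=0, ddof=1))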
@@ -153,6 +159,7 @@ class FeaturesConcatDataset(BaseConcatDataset):
         splits : dict
             A dictionary with the name of the split (a string) as key and the
             dataset as value.
+
         """
         if isinstance(by, str):
             split_ids = {
@@ -184,6 +191,7 @@ class FeaturesConcatDataset(BaseConcatDataset):
             DataFrame containing as many rows as there are windows in the
             BaseConcatDataset, with the metadata and description information
             for each window.
+
         """
         if not all([isinstance(ds, FeaturesDataset) for ds in self.datasets]):
             raise TypeError(
@@ -235,6 +243,7 @@ class FeaturesConcatDataset(BaseConcatDataset):
             concat. This is useful in the setting of very large datasets, where
             one dataset has to be processed and saved at a time to account for
             its original position.
+
         """
         if len(self.datasets) == 0:
             raise ValueError("Expect at least one dataset")
@@ -320,25 +329,53 @@ class FeaturesConcatDataset(BaseConcatDataset):
             json.dump(kwargs, f)
 
     def to_dataframe(
-        self, include_metadata=False, include_target=False, include_crop_inds=False
+        self,
+        include_metadata: bool | str | List[str] = False,
+        include_target: bool = False,
+        include_crop_inds: bool = False,
     ):
-        if include_metadata or (include_target and include_crop_inds):
+        if (
+            not isinstance(include_metadata, bool)
+            or include_metadata
+            or include_crop_inds
+        ):
+            include_dataset = False
+            if isinstance(include_metadata, bool) and include_metadata:
+                include_dataset = True
+                cols = self.datasets[0].metadata.columns
+            else:
+                cols = include_metadata
+                if isinstance(cols, bool) and not cols:
+                    cols = []
+                elif isinstance(cols, str):
+                    cols = [cols]
+            cols = set(cols)
+            if include_crop_inds:
+                cols = {
+                    "i_dataset",
+                    "i_window_in_trial",
+                    "i_start_in_trial",
+                    "i_stop_in_trial",
+                    *cols,
+                }
+            if include_target:
+                cols.add("target")
+            cols = list(cols)
+            include_dataset = "i_dataset" in cols
+            if include_dataset:
+                cols.remove("i_dataset")
             dataframes = [
-                ds.metadata.join(ds.features, how="right", lsuffix="_metadata")
+                ds.metadata[cols].join(ds.features, how="right", lsuffix="_metadata")
                 for ds in self.datasets
             ]
+            if include_dataset:
+                for i, df in enumerate(dataframes):
+                    df.insert(loc=0, column="i_dataset", value=i)
         elif include_target:
             dataframes = [
                 ds.features.join(ds.metadata["target"], how="left", rsuffix="_metadata")
                 for ds in self.datasets
             ]
-        elif include_crop_inds:
-            dataframes = [
-                ds.metadata.drop("target", axis="columns").join(
-                    ds.features, how="right", lsuffix="_metadata"
-                )
-                for ds in self.datasets
-            ]
         else:
             dataframes = [ds.features for ds in self.datasets]
         return pd.concat(dataframes, axis=0, ignore_index=True)
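With the widened signature, include_metadata can now also be a single column name or a list of metadata columns to join onto the features, and the old include_crop_inds branch is folded into the same column-selection path. A usage sketch, where concat_ds stands for an existing FeaturesConcatDataset and the column names are only examples:

    # Features only (unchanged default behaviour).
    df = concat_ds.to_dataframe()

    # Features plus selected metadata columns (new: a str or list of str).
    df = concat_ds.to_dataframe(include_metadata=["subject", "target"])

    # Features plus window crop indices and the target column.
    df = concat_ds.to_dataframe(include_crop_inds=True, include_target=True)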
@@ -374,7 +411,7 @@ class FeaturesConcatDataset(BaseConcatDataset):
                 return_count=True,
                 return_mean=True,
                 return_var=True,
-                ddof=ddof,
+                ddof=0,
                 numeric_only=numeric_only,
             )
             for ds in self.datasets
@@ -384,11 +421,13 @@ class FeaturesConcatDataset(BaseConcatDataset):
             np.array([s[1] for s in stats]),
             np.array([s[2] for s in stats]),
         )
-        _, _, var = _pooled_var(counts, means, variances, ddof)
+        _, _, var = _pooled_var(counts, means, variances, ddof, ddof_in=0)
         return pd.Series(var, index=self._numeric_columns())
 
-    def std(self, ddof=1, numeric_only=False, n_jobs=1):
-        return np.sqrt(self.var(ddof=ddof, numeric_only=numeric_only, n_jobs=n_jobs))
+    def std(self, ddof=1, numeric_only=False, eps=0, n_jobs=1):
+        return np.sqrt(
+            self.var(ddof=ddof, numeric_only=numeric_only, n_jobs=n_jobs) + eps
+        )
 
     def zscore(self, ddof=1, numeric_only=False, eps=0, n_jobs=1):
         stats = Parallel(n_jobs)(
@@ -397,7 +436,7 @@ class FeaturesConcatDataset(BaseConcatDataset):
                 return_count=True,
                 return_mean=True,
                 return_var=True,
-                ddof=ddof,
+                ddof=0,
                 numeric_only=numeric_only,
             )
             for ds in self.datasets
@@ -407,8 +446,8 @@ class FeaturesConcatDataset(BaseConcatDataset):
             np.array([s[1] for s in stats]),
             np.array([s[2] for s in stats]),
         )
-        _, mean, var = _pooled_var(counts, means, variances, ddof)
-        std = np.sqrt(var) + eps
+        _, mean, var = _pooled_var(counts, means, variances, ddof, ddof_in=0)
+        std = np.sqrt(var + eps)
         for ds in self.datasets:
             ds.features = (ds.features - mean) / std
 
@@ -1,14 +1,14 @@
-from typing import List, Type
 from collections.abc import Callable
+from typing import List, Type
 
 from .extractors import (
-    FeatureExtractor,
-    UnivariateFeature,
     BivariateFeature,
     DirectedBivariateFeature,
+    FeatureExtractor,
     MultivariateFeature,
+    UnivariateFeature,
+    _get_underlying_func,
 )
-from .extractors import _get_underlying_func
 
 
 class FeaturePredecessor:
@@ -38,6 +38,14 @@ class FeatureKind:
 
 # Syntax sugar
 univariate_feature = FeatureKind(UnivariateFeature())
-bivariate_feature = FeatureKind(BivariateFeature())
-directed_bivariate_feature = FeatureKind(DirectedBivariateFeature())
+
+
+def bivariate_feature(func, directed=False):
+    if directed:
+        kind = DirectedBivariateFeature()
+    else:
+        kind = BivariateFeature()
+    return FeatureKind(kind)(func)
+
+
 multivariate_feature = FeatureKind(MultivariateFeature())
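bivariate_feature changes from a ready-made FeatureKind instance into a small factory, so the directed variant is now selected with a flag instead of the removed directed_bivariate_feature name. A sketch of both spellings; the feature bodies are placeholders and the import path assumes the features subpackage is eegdash.features:

    from functools import partial

    from eegdash.features import bivariate_feature

    @bivariate_feature                           # undirected pairs (default)
    def my_coupling(f, c):                       # placeholder signature and body
        ...

    @partial(bivariate_feature, directed=True)   # ordered (directed) pairs
    def my_directed_coupling(f, c):              # placeholder signature and body
        ...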
@@ -1,7 +1,8 @@
 from abc import ABC, abstractmethod
-from typing import Dict
 from collections.abc import Callable
 from functools import partial
+from typing import Dict
+
 import numpy as np
 from numba.core.dispatcher import Dispatcher
 
@@ -15,9 +16,9 @@ def _get_underlying_func(func):
     return f
 
 
-class FitableFeature(ABC):
+class TrainableFeature(ABC):
     def __init__(self):
-        self._is_fitted = False
+        self._is_trained = False
         self.clear()
 
     @abstractmethod
@@ -38,12 +39,12 @@ class FitableFeature(ABC):
         )
 
 
-class FeatureExtractor(FitableFeature):
+class FeatureExtractor(TrainableFeature):
     def __init__(
         self, feature_extractors: Dict[str, Callable], **preprocess_kwargs: Dict
     ):
         self.feature_extractors_dict = self._validate_execution_tree(feature_extractors)
-        self._is_fitable = self._check_is_fitable(feature_extractors)
+        self._is_trainable = self._check_is_trainable(feature_extractors)
         super().__init__()
 
         # bypassing FeaturePredecessor to avoid circular import
@@ -69,32 +70,31 @@ class FeatureExtractor(FitableFeature):
             assert type(self) in pe_type
         return feature_extractors
 
-    def _check_is_fitable(self, feature_extractors):
-        is_fitable = False
+    def _check_is_trainable(self, feature_extractors):
+        is_trainable = False
         for fname, f in feature_extractors.items():
             if isinstance(f, FeatureExtractor):
-                is_fitable = f._is_fitable
+                is_trainable = f._is_trainable
             else:
                 f = _get_underlying_func(f)
-                if isinstance(f, FitableFeature):
-                    is_fitable = True
-            if is_fitable:
+                if isinstance(f, TrainableFeature):
+                    is_trainable = True
+            if is_trainable:
                 break
-        return is_fitable
+        return is_trainable
 
     def preprocess(self, *x, **kwargs):
         return (*x,)
 
-    def feature_channel_names(self, ch_names):
-        return [""]
-
     def __call__(self, *x, _batch_size=None, _ch_names=None):
         assert _batch_size is not None
         assert _ch_names is not None
-        if self._is_fitable:
+        if self._is_trainable:
             super().__call__()
         results_dict = dict()
         z = self.preprocess(*x, **self.preprocess_kwargs)
+        if not isinstance(z, tuple):
+            z = (z,)
         for fname, f in self.feature_extractors_dict.items():
             if isinstance(f, FeatureExtractor):
                 r = f(*z, _batch_size=_batch_size, _ch_names=_ch_names)
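The rename from fitable to trainable continues in the training-related methods in the next hunk (clear, partial_fit, fit), while the lifecycle itself is unchanged: reset state, accumulate over batches, then finalize before extraction. A self-contained illustration of that pattern in plain Python, deliberately not using the package's own base class:

    import numpy as np

    class RunningMeanFeature:
        # Stand-in for a trainable feature (illustration only, not the package API).
        def __init__(self):
            self.clear()

        def clear(self):                   # reset accumulated state
            self._sum, self._count = 0.0, 0

        def partial_fit(self, x, y=None):  # accumulate statistics over one batch
            self._sum += float(np.sum(x))
            self._count += x.size

        def fit(self):                     # finalize once all batches were seen
            self._mean = self._sum / max(self._count, 1)

        def __call__(self, x):             # use the trained state at extraction time
            return x.mean(axis=-1) - self._mean

    feat = RunningMeanFeature()
    for batch in np.split(np.arange(12.0), 3):
        feat.partial_fit(batch)
    feat.fit()
    print(feat(np.arange(12.0).reshape(3, 4)))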
@@ -125,28 +125,28 @@ class FeatureExtractor(FitableFeature):
             results_dict[name] = value
 
     def clear(self):
-        if not self._is_fitable:
+        if not self._is_trainable:
             return
         for fname, f in self.feature_extractors_dict.items():
             f = _get_underlying_func(f)
-            if isinstance(f, FitableFeature):
+            if isinstance(f, TrainableFeature):
                 f.clear()
 
     def partial_fit(self, *x, y=None):
-        if not self._is_fitable:
+        if not self._is_trainable:
             return
         z = self.preprocess(*x, **self.preprocess_kwargs)
         for fname, f in self.feature_extractors_dict.items():
             f = _get_underlying_func(f)
-            if isinstance(f, FitableFeature):
+            if isinstance(f, TrainableFeature):
                 f.partial_fit(*z, y=y)
 
     def fit(self):
-        if not self._is_fitable:
+        if not self._is_trainable:
             return
         for fname, f in self.feature_extractors_dict.items():
             f = _get_underlying_func(f)
-            if isinstance(f, FitableFeature):
+            if isinstance(f, TrainableFeature):
                 f.fit()
         super().fit()
 
@@ -1,6 +1,6 @@
-from .signal import *
-from .spectral import *
 from .complexity import *
-from .dimensionality import *
 from .connectivity import *
 from .csp import *
+from .dimensionality import *
+from .signal import *
+from .spectral import *
@@ -1,10 +1,10 @@
-import numpy as np
 import numba as nb
+import numpy as np
 from sklearn.neighbors import KDTree
 
-from ..extractors import FeatureExtractor
 from ..decorators import FeaturePredecessor, univariate_feature
-
+from ..extractors import FeatureExtractor
+from .signal import SIGNAL_PREDECESSORS
 
 __all__ = [
     "EntropyFeatureExtractor",
@@ -29,6 +29,7 @@ def _channel_app_samp_entropy_counts(x, m, r, l):
     return kdtree.query_radius(x_emb, r, count_only=True)
 
 
+@FeaturePredecessor(*SIGNAL_PREDECESSORS)
 class EntropyFeatureExtractor(FeatureExtractor):
     def preprocess(self, x, m=2, r=0.2, l=1):
         rr = r * x.std(axis=-1)
@@ -56,6 +57,7 @@ def complexity_sample_entropy(counts_m, counts_mp1):
     return -np.log(A / B)
 
 
+@FeaturePredecessor(*SIGNAL_PREDECESSORS)
 @univariate_feature
 def complexity_svd_entropy(x, m=10, tau=1):
     x_emb = np.empty((*x.shape[:-1], (x.shape[-1] - m + 1) // tau, m))
@@ -66,6 +68,7 @@ def complexity_svd_entropy(x, m=10, tau=1):
     return -np.sum(s * np.log(s), axis=-1)
 
 
+@FeaturePredecessor(*SIGNAL_PREDECESSORS)
 @univariate_feature
 @nb.njit(cache=True, fastmath=True)
 def complexity_lempel_ziv(x, threshold=None):
@@ -1,10 +1,11 @@
 from itertools import chain
+
 import numpy as np
 from scipy.signal import csd
 
-from ..extractors import FeatureExtractor, BivariateFeature
 from ..decorators import FeaturePredecessor, bivariate_feature
-
+from ..extractors import BivariateFeature, FeatureExtractor
+from . import utils
 
 __all__ = [
     "CoherenceFeatureExtractor",
@@ -18,82 +19,41 @@ class CoherenceFeatureExtractor(FeatureExtractor):
     def preprocess(self, x, **kwargs):
         f_min = kwargs.pop("f_min") if "f_min" in kwargs else None
         f_max = kwargs.pop("f_max") if "f_max" in kwargs else None
+        assert "fs" in kwargs and "nperseg" in kwargs
         kwargs["axis"] = -1
         n = x.shape[1]
         idx_x, idx_y = BivariateFeature.get_pair_iterators(n)
         ix, iy = list(chain(range(n), idx_x)), list(chain(range(n), idx_y))
         f, s = csd(x[:, ix], x[:, iy], **kwargs)
-        if f_min is not None or f_max is not None:
-            f_min_idx = f > f_min if f_min is not None else True
-            f_max_idx = f < f_max if f_max is not None else True
-            idx = np.logical_and(f_min_idx, f_max_idx)
-            f = f[idx]
-            s = s[..., idx]
-        sx, sxy = np.split(s, [n], axis=1)
-        sxx, syy = sx[:, idx_x].real, sx[:, idx_y].real
+        f_min, f_max = utils.get_valid_freq_band(
+            kwargs["fs"], x.shape[-1], f_min, f_max
+        )
+        f, s = utils.slice_freq_band(f, s, f_min=f_min, f_max=f_max)
+        p, sxy = np.split(s, [n], axis=1)
+        sxx, syy = p[:, idx_x].real, p[:, idx_y].real
         c = sxy / np.sqrt(sxx * syy)
         return f, c
 
 
-def _avg_over_bands(f, x, bands):
-    bands_avg = dict()
-    for k, v in bands.items():
-        assert isinstance(k, str)
-        assert isinstance(v, tuple)
-        assert len(v) == 2
-        assert v[0] < v[1]
-        mask = np.logical_and(f > v[0], f < v[1])
-        avg = x[..., mask].mean(axis=-1)
-        bands_avg[k] = avg
-    return bands_avg
-
-
 @FeaturePredecessor(CoherenceFeatureExtractor)
 @bivariate_feature
-def connectivity_magnitude_square_coherence(
-    f,
-    c,
-    bands={
-        "delta": (1, 4.5),
-        "theta": (4.5, 8),
-        "alpha": (8, 12),
-        "beta": (12, 30),
-    },
-):
+def connectivity_magnitude_square_coherence(f, c, bands=utils.DEFAULT_FREQ_BANDS):
     # https://neuroimage.usc.edu/brainstorm/Tutorials/Connectivity
     coher = c.real**2 + c.imag**2
-    return _avg_over_bands(f, coher, bands)
+    return utils.reduce_freq_bands(f, coher, bands, np.mean)
 
 
 @FeaturePredecessor(CoherenceFeatureExtractor)
 @bivariate_feature
-def connectivity_imaginary_coherence(
-    f,
-    c,
-    bands={
-        "delta": (1, 4.5),
-        "theta": (4.5, 8),
-        "alpha": (8, 12),
-        "beta": (12, 30),
-    },
-):
+def connectivity_imaginary_coherence(f, c, bands=utils.DEFAULT_FREQ_BANDS):
     # https://neuroimage.usc.edu/brainstorm/Tutorials/Connectivity
     coher = c.imag
-    return _avg_over_bands(f, coher, bands)
+    return utils.reduce_freq_bands(f, coher, bands, np.mean)
 
 
 @FeaturePredecessor(CoherenceFeatureExtractor)
 @bivariate_feature
-def connectivity_lagged_coherence(
-    f,
-    c,
-    bands={
-        "delta": (1, 4.5),
-        "theta": (4.5, 8),
-        "alpha": (8, 12),
-        "beta": (12, 30),
-    },
-):
+def connectivity_lagged_coherence(f, c, bands=utils.DEFAULT_FREQ_BANDS):
     # https://neuroimage.usc.edu/brainstorm/Tutorials/Connectivity
     coher = c.imag / np.sqrt(1 - c.real)
-    return _avg_over_bands(f, coher, bands)
+    return utils.reduce_freq_bands(f, coher, bands, np.mean)
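The hand-rolled _avg_over_bands helper and the per-function band defaults are replaced by utils.DEFAULT_FREQ_BANDS plus utils.reduce_freq_bands with an explicit reduction function. Neither utility is shown in this diff; a standalone NumPy sketch of the same band-averaging, with the band edges taken from the deleted defaults:

    import numpy as np

    FREQ_BANDS = {  # band edges from the removed default arguments
        "delta": (1, 4.5),
        "theta": (4.5, 8),
        "alpha": (8, 12),
        "beta": (12, 30),
    }

    def avg_over_bands(f, x, bands=FREQ_BANDS):
        # Average the last axis of x over each (f_min, f_max) frequency band.
        return {
            name: x[..., (f > lo) & (f < hi)].mean(axis=-1)
            for name, (lo, hi) in bands.items()
        }

    f = np.linspace(0, 50, 256)                              # frequency grid, e.g. from scipy.signal.csd
    coh = np.random.default_rng(0).random((8, 28, f.size))   # (epochs, pairs, freqs)
    print({k: v.shape for k, v in avg_over_bands(f, coh).items()})  # each (8, 28)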
@@ -1,11 +1,10 @@
-import numpy as np
 import numba as nb
+import numpy as np
 import scipy
 import scipy.linalg
 
-from ..extractors import FitableFeature
 from ..decorators import multivariate_feature
-
+from ..extractors import TrainableFeature
 
 __all__ = [
     "CommonSpatialPattern",
@@ -23,7 +22,7 @@ def _update_mean_cov(count, mean, cov, x_count, x_mean, x_cov):
 
 
 @multivariate_feature
-class CommonSpatialPattern(FitableFeature):
+class CommonSpatialPattern(TrainableFeature):
     def __init__(self):
         super().__init__()