eegdash 0.0.9__py3-none-any.whl → 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


@@ -1,16 +1,19 @@
 from __future__ import annotations
-import os
+
 import json
+import os
 import shutil
 import warnings
-from typing import Dict, no_type_check
 from collections.abc import Callable, Iterable
+from typing import Dict, no_type_check
+
 import numpy as np
 import pandas as pd
 from joblib import Parallel, delayed
+
 from braindecode.datasets.base import (
-    EEGWindowsDataset,
     BaseConcatDataset,
+    EEGWindowsDataset,
     _create_description,
 )
 
@@ -1,14 +1,14 @@
-from typing import List, Type
 from collections.abc import Callable
+from typing import List, Type
 
 from .extractors import (
-    FeatureExtractor,
-    UnivariateFeature,
     BivariateFeature,
     DirectedBivariateFeature,
+    FeatureExtractor,
     MultivariateFeature,
+    UnivariateFeature,
+    _get_underlying_func,
 )
-from .extractors import _get_underlying_func
 
 
 class FeaturePredecessor:
@@ -1,7 +1,8 @@
 from abc import ABC, abstractmethod
-from typing import Dict
 from collections.abc import Callable
 from functools import partial
+from typing import Dict
+
 import numpy as np
 from numba.core.dispatcher import Dispatcher
 
@@ -1,6 +1,6 @@
-from .signal import *
-from .spectral import *
 from .complexity import *
-from .dimensionality import *
 from .connectivity import *
 from .csp import *
+from .dimensionality import *
+from .signal import *
+from .spectral import *
@@ -1,10 +1,9 @@
-import numpy as np
 import numba as nb
+import numpy as np
 from sklearn.neighbors import KDTree
 
-from ..extractors import FeatureExtractor
 from ..decorators import FeaturePredecessor, univariate_feature
-
+from ..extractors import FeatureExtractor
 
 __all__ = [
     "EntropyFeatureExtractor",
@@ -1,10 +1,11 @@
 from itertools import chain
+
 import numpy as np
 from scipy.signal import csd
 
-from ..extractors import FeatureExtractor, BivariateFeature
 from ..decorators import FeaturePredecessor, bivariate_feature
-
+from ..extractors import BivariateFeature, FeatureExtractor
+from . import utils
 
 __all__ = [
     "CoherenceFeatureExtractor",
@@ -18,82 +19,41 @@ class CoherenceFeatureExtractor(FeatureExtractor):
     def preprocess(self, x, **kwargs):
         f_min = kwargs.pop("f_min") if "f_min" in kwargs else None
         f_max = kwargs.pop("f_max") if "f_max" in kwargs else None
+        assert "fs" in kwargs and "nperseg" in kwargs
         kwargs["axis"] = -1
         n = x.shape[1]
         idx_x, idx_y = BivariateFeature.get_pair_iterators(n)
         ix, iy = list(chain(range(n), idx_x)), list(chain(range(n), idx_y))
         f, s = csd(x[:, ix], x[:, iy], **kwargs)
-        if f_min is not None or f_max is not None:
-            f_min_idx = f > f_min if f_min is not None else True
-            f_max_idx = f < f_max if f_max is not None else True
-            idx = np.logical_and(f_min_idx, f_max_idx)
-            f = f[idx]
-            s = s[..., idx]
-        sx, sxy = np.split(s, [n], axis=1)
-        sxx, syy = sx[:, idx_x].real, sx[:, idx_y].real
+        f_min, f_max = utils.get_valid_freq_band(
+            kwargs["fs"], x.shape[-1], f_min, f_max
+        )
+        f, s = utils.slice_freq_band(f, s, f_min=f_min, f_max=f_max)
+        p, sxy = np.split(s, [n], axis=1)
+        sxx, syy = p[:, idx_x].real, p[:, idx_y].real
         c = sxy / np.sqrt(sxx * syy)
         return f, c
 
 
-def _avg_over_bands(f, x, bands):
-    bands_avg = dict()
-    for k, v in bands.items():
-        assert isinstance(k, str)
-        assert isinstance(v, tuple)
-        assert len(v) == 2
-        assert v[0] < v[1]
-        mask = np.logical_and(f > v[0], f < v[1])
-        avg = x[..., mask].mean(axis=-1)
-        bands_avg[k] = avg
-    return bands_avg
-
-
 @FeaturePredecessor(CoherenceFeatureExtractor)
 @bivariate_feature
-def connectivity_magnitude_square_coherence(
-    f,
-    c,
-    bands={
-        "delta": (1, 4.5),
-        "theta": (4.5, 8),
-        "alpha": (8, 12),
-        "beta": (12, 30),
-    },
-):
+def connectivity_magnitude_square_coherence(f, c, bands=utils.DEFAULT_FREQ_BANDS):
     # https://neuroimage.usc.edu/brainstorm/Tutorials/Connectivity
     coher = c.real**2 + c.imag**2
-    return _avg_over_bands(f, coher, bands)
+    return utils.reduce_freq_bands(f, coher, bands, np.mean)
 
 
 @FeaturePredecessor(CoherenceFeatureExtractor)
 @bivariate_feature
-def connectivity_imaginary_coherence(
-    f,
-    c,
-    bands={
-        "delta": (1, 4.5),
-        "theta": (4.5, 8),
-        "alpha": (8, 12),
-        "beta": (12, 30),
-    },
-):
+def connectivity_imaginary_coherence(f, c, bands=utils.DEFAULT_FREQ_BANDS):
     # https://neuroimage.usc.edu/brainstorm/Tutorials/Connectivity
     coher = c.imag
-    return _avg_over_bands(f, coher, bands)
+    return utils.reduce_freq_bands(f, coher, bands, np.mean)
 
 
 @FeaturePredecessor(CoherenceFeatureExtractor)
 @bivariate_feature
-def connectivity_lagged_coherence(
-    f,
-    c,
-    bands={
-        "delta": (1, 4.5),
-        "theta": (4.5, 8),
-        "alpha": (8, 12),
-        "beta": (12, 30),
-    },
-):
+def connectivity_lagged_coherence(f, c, bands=utils.DEFAULT_FREQ_BANDS):
     # https://neuroimage.usc.edu/brainstorm/Tutorials/Connectivity
     coher = c.imag / np.sqrt(1 - c.real)
-    return _avg_over_bands(f, coher, bands)
+    return utils.reduce_freq_bands(f, coher, bands, np.mean)
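
All three connectivity features above consume the same complex coherency c = Sxy / sqrt(Sxx * Syy) produced by preprocess and differ only in how they collapse it to a real value before band averaging. A minimal standalone sketch of those three reductions on toy coherency values (plain NumPy, independent of the eegdash API):

import numpy as np

# Toy complex coherency values for one channel pair over five frequency bins.
c = np.array([0.60 + 0.30j, 0.50 + 0.10j, 0.20 + 0.40j, 0.10 + 0.20j, 0.05 + 0.05j])

msc = c.real**2 + c.imag**2           # magnitude-squared coherence
icoh = c.imag                         # imaginary coherence
lcoh = c.imag / np.sqrt(1 - c.real)   # lagged coherence

print(msc, icoh, lcoh, sep="\n")
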
@@ -1,11 +1,10 @@
-import numpy as np
 import numba as nb
+import numpy as np
 import scipy
 import scipy.linalg
 
-from ..extractors import FitableFeature
 from ..decorators import multivariate_feature
-
+from ..extractors import FitableFeature
 
 __all__ = [
     "CommonSpatialPattern",
@@ -1,11 +1,10 @@
-import numpy as np
 import numba as nb
+import numpy as np
 from scipy import special
 
 from ..decorators import univariate_feature
 from .signal import signal_zero_crossings
 
-
 __all__ = [
     "dimensionality_higuchi_fractal_dim",
     "dimensionality_petrosian_fractal_dim",
@@ -1,10 +1,10 @@
 import numbers
+
 import numpy as np
 from scipy import stats
 
 from ..decorators import univariate_feature
 
-
 __all__ = [
     "signal_mean",
     "signal_variance",
@@ -1,10 +1,10 @@
-import numpy as np
 import numba as nb
+import numpy as np
 from scipy.signal import welch
 
-from ..extractors import FeatureExtractor
 from ..decorators import FeaturePredecessor, univariate_feature
-
+from ..extractors import FeatureExtractor
+from . import utils
 
 __all__ = [
     "SpectralFeatureExtractor",
@@ -26,14 +26,13 @@ class SpectralFeatureExtractor(FeatureExtractor):
     def preprocess(self, x, **kwargs):
         f_min = kwargs.pop("f_min") if "f_min" in kwargs else None
         f_max = kwargs.pop("f_max") if "f_max" in kwargs else None
+        assert "fs" in kwargs
         kwargs["axis"] = -1
         f, p = welch(x, **kwargs)
-        if f_min is not None or f_max is not None:
-            f_min_idx = f > f_min if f_min is not None else True
-            f_max_idx = f < f_max if f_max is not None else True
-            idx = np.logical_and(f_min_idx, f_max_idx)
-            f = f[idx]
-            p = p[..., idx]
+        f_min, f_max = utils.get_valid_freq_band(
+            kwargs["fs"], x.shape[-1], f_min, f_max
+        )
+        f, p = utils.slice_freq_band(f, p, f_min=f_min, f_max=f_max)
         return f, p
 
 
@@ -113,22 +112,5 @@ def spectral_slope(f, p):
     DBSpectralFeatureExtractor,
 )
 @univariate_feature
-def spectral_bands_power(
-    f,
-    p,
-    bands={
-        "delta": (1, 4.5),
-        "theta": (4.5, 8),
-        "alpha": (8, 12),
-        "beta": (12, 30),
-    },
-):
-    bands_power = dict()
-    for k, v in bands.items():
-        assert isinstance(k, str)
-        assert isinstance(v, tuple)
-        assert len(v) == 2
-        mask = np.logical_and(f > v[0], f < v[1])
-        power = p[..., mask].sum(axis=-1)
-        bands_power[k] = power
-    return bands_power
+def spectral_bands_power(f, p, bands=utils.DEFAULT_FREQ_BANDS):
+    return utils.reduce_freq_bands(f, p, bands, np.sum)
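
The retired inline loop in spectral_bands_power is equivalent to summing PSD bins per band, which utils.reduce_freq_bands now does generically. A minimal standalone sketch of that per-band reduction on a fake PSD (NumPy only; band edges follow the new f >= low, f < high convention):

import numpy as np

f = np.arange(1.0, 41.0)              # 1..40 Hz at 1 Hz resolution
p = np.random.rand(8, f.size)         # fake PSD for 8 channels

bands = {"delta": (1, 4.5), "theta": (4.5, 8), "alpha": (8, 12), "beta": (12, 30)}
bands_power = {
    k: p[..., (f >= lo) & (f < hi)].sum(axis=-1) for k, (lo, hi) in bands.items()
}
print({k: v.shape for k, v in bands_power.items()})   # each entry has shape (8,)
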
@@ -0,0 +1,48 @@
+import numpy as np
+
+DEFAULT_FREQ_BANDS = {
+    "delta": (1, 4.5),
+    "theta": (4.5, 8),
+    "alpha": (8, 12),
+    "beta": (12, 30),
+}
+
+
+def get_valid_freq_band(fs, n, f_min=None, f_max=None):
+    f0 = 2 * fs / n
+    f1 = fs / 2
+    if f_min is None:
+        f_min = f0
+    else:
+        assert f_min >= f0
+    if f_max is None:
+        f_max = f1
+    else:
+        assert f_max <= f1
+    return f_min, f_max
+
+
+def slice_freq_band(f, *x, f_min=None, f_max=None):
+    if f_min is None and f_max is None:
+        return f, *x
+    else:
+        f_min_idx = f >= f_min if f_min is not None else True
+        f_max_idx = f <= f_max if f_max is not None else True
+        idx = np.logical_and(f_min_idx, f_max_idx)
+        f = f[idx]
+        xl = [*x]
+        for i, xi in enumerate(xl):
+            xl[i] = xi[..., idx]
+        return f, *xl
+
+
+def reduce_freq_bands(f, x, bands, reduce_func=np.sum):
+    x_bands = dict()
+    for k, lims in bands.items():
+        assert isinstance(k, str)
+        assert len(lims) == 2 and lims[0] <= lims[1]
+        assert lims[0] >= f[0] and lims[1] <= f[-1]
+        mask = np.logical_and(f >= lims[0], f < lims[1])
+        xf = x[..., mask]
+        x_bands[k] = reduce_func(xf, axis=-1)
+    return x_bands
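
Taken together, the new helpers clamp a requested frequency band to what the window can resolve ([2 * fs / n, fs / 2]), slice the spectrum to that band, and reduce it per canonical band. A usage sketch; the import path is hypothetical, since the diff adds the module but does not show where it lives:

import numpy as np
from scipy.signal import welch

# Hypothetical import path -- adjust to wherever the new utils module is placed.
from eegdash.features.feature_bank import utils

fs, n = 256, 1024
x = np.random.randn(8, n)                                # fake EEG, 8 channels

f, p = welch(x, fs=fs, nperseg=256, axis=-1)             # one-sided PSD, 1 Hz bins

# Clamp the requested band to the resolvable range [2 * fs / n, fs / 2].
f_min, f_max = utils.get_valid_freq_band(fs, n, None, 40)
f, p = utils.slice_freq_band(f, p, f_min=f_min, f_max=f_max)

# Per-channel power in each canonical band.
band_power = utils.reduce_freq_bands(f, p, utils.DEFAULT_FREQ_BANDS, np.sum)
print({k: v.shape for k, v in band_power.items()})       # each entry: (8,)
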
@@ -9,13 +9,13 @@ from pathlib import Path
 
 import pandas as pd
 from joblib import Parallel, delayed
-
 from mne.io import read_info
+
 from braindecode.datautil.serialization import _load_kwargs_json
 
 from .datasets import (
-    FeaturesDataset,
     FeaturesConcatDataset,
+    FeaturesDataset,
 )
 
 
eegdash/features/utils.py CHANGED
@@ -1,18 +1,20 @@
-from typing import Dict, List
-from collections.abc import Callable
 import copy
+from collections.abc import Callable
+from typing import Dict, List
+
 import numpy as np
 import pandas as pd
 from joblib import Parallel, delayed
-from tqdm import tqdm
 from torch.utils.data import DataLoader
+from tqdm import tqdm
+
 from braindecode.datasets.base import (
+    BaseConcatDataset,
     EEGWindowsDataset,
     WindowsDataset,
-    BaseConcatDataset,
 )
 
-from .datasets import FeaturesDataset, FeaturesConcatDataset
+from .datasets import FeaturesConcatDataset, FeaturesDataset
 from .extractors import FeatureExtractor
 
 
@@ -53,7 +55,7 @@ def _extract_features_from_windowsdataset(
     metadata.reset_index(drop=True, inplace=True)
     metadata.drop("orig_index", axis=1, inplace=True)
 
-    # FUTURE: truely support WindowsDataset objects
+    # FUTURE: truly support WindowsDataset objects
     return FeaturesDataset(
         features_df,
         metadata=metadata,