eegdash 0.0.8__py3-none-any.whl → 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

@@ -0,0 +1,6 @@
+ from .complexity import *
+ from .connectivity import *
+ from .csp import *
+ from .dimensionality import *
+ from .signal import *
+ from .spectral import *
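
Taken together, these re-exports expose every feature module from the subpackage root. A minimal sketch of the calling convention (the exact import path and the pass-through behaviour of the univariate_feature decorator are assumptions; the (channels, times) layout is inferred from the axis=-1 reductions used throughout the new modules):

import numpy as np
from eegdash.features import signal_mean, signal_variance  # assumed import path

x = np.random.randn(4, 256)   # 4 channels, 256 samples
print(signal_mean(x))         # expected: per-channel means, shape (4,)
print(signal_variance(x))
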
@@ -0,0 +1,96 @@
+ import numba as nb
+ import numpy as np
+ from sklearn.neighbors import KDTree
+
+ from ..decorators import FeaturePredecessor, univariate_feature
+ from ..extractors import FeatureExtractor
+
+ __all__ = [
+     "EntropyFeatureExtractor",
+     "complexity_approx_entropy",
+     "complexity_sample_entropy",
+     "complexity_svd_entropy",
+     "complexity_lempel_ziv",
+ ]
+
+
+ @nb.njit(cache=True, fastmath=True)
+ def _create_embedding(x, dim, lag):
+     # Time-delay embedding: one row per length-`dim` window, windows stepped by `lag`.
+     y = np.empty(((x.shape[-1] - dim) // lag + 1, dim))
+     for i in range(0, x.shape[-1] - dim + 1, lag):
+         y[i // lag] = x[i : i + dim]
+     return y
+
+
+ def _channel_app_samp_entropy_counts(x, m, r, l):
+     # For each embedded vector, count neighbours within Chebyshev radius r
+     # (self-matches included).
+     x_emb = _create_embedding(x, m, l)
+     kdtree = KDTree(x_emb, metric="chebyshev")
+     return kdtree.query_radius(x_emb, r, count_only=True)
+
+
+ class EntropyFeatureExtractor(FeatureExtractor):
+     def preprocess(self, x, m=2, r=0.2, l=1):
+         rr = r * x.std(axis=-1)
+         # Row counts match the embeddings of dimension m and m + 1.
+         counts_m = np.empty((*x.shape[:-1], (x.shape[-1] - m) // l + 1))
+         counts_mp1 = np.empty((*x.shape[:-1], (x.shape[-1] - m - 1) // l + 1))
+         for i in np.ndindex(x.shape[:-1]):
+             counts_m[*i, :] = _channel_app_samp_entropy_counts(x[i], m, rr[i], l)
+             counts_mp1[*i, :] = _channel_app_samp_entropy_counts(x[i], m + 1, rr[i], l)
+         return counts_m, counts_mp1
+
+
+ @FeaturePredecessor(EntropyFeatureExtractor)
+ @univariate_feature
+ def complexity_approx_entropy(counts_m, counts_mp1):
+     phi_m = np.log(counts_m / counts_m.shape[-1]).mean(axis=-1)
+     phi_mp1 = np.log(counts_mp1 / counts_mp1.shape[-1]).mean(axis=-1)
+     return phi_m - phi_mp1
+
+
+ @FeaturePredecessor(EntropyFeatureExtractor)
+ @univariate_feature
+ def complexity_sample_entropy(counts_m, counts_mp1):
+     # Subtract 1 from each count to exclude the self-match.
+     A = np.sum(counts_mp1 - 1, axis=-1)
+     B = np.sum(counts_m - 1, axis=-1)
+     return -np.log(A / B)
+
+
+ @univariate_feature
+ def complexity_svd_entropy(x, m=10, tau=1):
+     x_emb = np.empty((*x.shape[:-1], (x.shape[-1] - m) // tau + 1, m))
+     for i in np.ndindex(x.shape[:-1]):
+         x_emb[*i, :, :] = _create_embedding(x[i], m, tau)
+     s = np.linalg.svdvals(x_emb)
+     s /= s.sum(axis=-1, keepdims=True)
+     return -np.sum(s * np.log(s), axis=-1)
+
+
+ @univariate_feature
+ @nb.njit(cache=True, fastmath=True)
+ def complexity_lempel_ziv(x, threshold=None):
+     # Kaspar-Schuster Lempel-Ziv complexity of the signal binarised around
+     # `threshold` (the channel median by default).
+     lzc = np.empty(x.shape[:-1])
+     for i in np.ndindex(x.shape[:-1]):
+         t = np.median(x[i]) if threshold is None else threshold
+         s = x[i] > t
+         n = s.shape[0]
+         j, k, l = 0, 1, 1
+         k_max = 1
+         lzc[i] = 1
+         while True:
+             if s[j + k - 1] == s[l + k - 1]:
+                 k += 1
+                 if l + k > n:
+                     lzc[i] += 1
+                     break
+             else:
+                 k_max = np.maximum(k, k_max)
+                 j += 1
+                 if j == l:
+                     lzc[i] += 1
+                     l += k_max
+                     if l + 1 > n:
+                         break
+                     j, k, k_max = 0, 1, 1
+                 else:
+                     k = 1
+     return lzc
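
The approximate- and sample-entropy features share the radius-count preprocessing above. For reference, a self-contained sample-entropy sketch for a single channel (plain NumPy/scikit-learn, no extractor machinery; sliding_window_view stands in for _create_embedding with lag 1, and m and r follow the preprocess defaults):

import numpy as np
from sklearn.neighbors import KDTree

def sample_entropy(x, m=2, r=0.2):
    r = r * x.std()
    def counts(dim):
        emb = np.lib.stride_tricks.sliding_window_view(x, dim)
        return KDTree(emb, metric="chebyshev").query_radius(emb, r, count_only=True)
    A = np.sum(counts(m + 1) - 1)  # template matches of length m + 1, self-matches excluded
    B = np.sum(counts(m) - 1)      # template matches of length m
    return -np.log(A / B)

print(sample_entropy(np.random.randn(500)))
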
@@ -0,0 +1,59 @@
+ from itertools import chain
+
+ import numpy as np
+ from scipy.signal import csd
+
+ from ..decorators import FeaturePredecessor, bivariate_feature
+ from ..extractors import BivariateFeature, FeatureExtractor
+ from . import utils
+
+ __all__ = [
+     "CoherenceFeatureExtractor",
+     "connectivity_magnitude_square_coherence",
+     "connectivity_imaginary_coherence",
+     "connectivity_lagged_coherence",
+ ]
+
+
+ class CoherenceFeatureExtractor(FeatureExtractor):
+     def preprocess(self, x, **kwargs):
+         f_min = kwargs.pop("f_min", None)
+         f_max = kwargs.pop("f_max", None)
+         assert "fs" in kwargs and "nperseg" in kwargs
+         kwargs["axis"] = -1
+         n = x.shape[1]
+         idx_x, idx_y = BivariateFeature.get_pair_iterators(n)
+         # First n pairs are (ch, ch) auto-spectra, the rest are cross-spectra.
+         ix, iy = list(chain(range(n), idx_x)), list(chain(range(n), idx_y))
+         f, s = csd(x[:, ix], x[:, iy], **kwargs)
+         f_min, f_max = utils.get_valid_freq_band(
+             kwargs["fs"], x.shape[-1], f_min, f_max
+         )
+         f, s = utils.slice_freq_band(f, s, f_min=f_min, f_max=f_max)
+         p, sxy = np.split(s, [n], axis=1)
+         sxx, syy = p[:, idx_x].real, p[:, idx_y].real
+         c = sxy / np.sqrt(sxx * syy)  # complex coherency
+         return f, c
+
+
+ @FeaturePredecessor(CoherenceFeatureExtractor)
+ @bivariate_feature
+ def connectivity_magnitude_square_coherence(f, c, bands=utils.DEFAULT_FREQ_BANDS):
+     # https://neuroimage.usc.edu/brainstorm/Tutorials/Connectivity
+     coher = c.real**2 + c.imag**2
+     return utils.reduce_freq_bands(f, coher, bands, np.mean)
+
+
+ @FeaturePredecessor(CoherenceFeatureExtractor)
+ @bivariate_feature
+ def connectivity_imaginary_coherence(f, c, bands=utils.DEFAULT_FREQ_BANDS):
+     # https://neuroimage.usc.edu/brainstorm/Tutorials/Connectivity
+     coher = c.imag
+     return utils.reduce_freq_bands(f, coher, bands, np.mean)
+
+
+ @FeaturePredecessor(CoherenceFeatureExtractor)
+ @bivariate_feature
+ def connectivity_lagged_coherence(f, c, bands=utils.DEFAULT_FREQ_BANDS):
+     # https://neuroimage.usc.edu/brainstorm/Tutorials/Connectivity
+     # Lagged coherence: imag(C) / sqrt(1 - real(C)**2).
+     coher = c.imag / np.sqrt(1 - c.real**2)
+     return utils.reduce_freq_bands(f, coher, bands, np.mean)
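
For a single channel pair the same quantities can be computed directly with scipy, which makes the extractor's bookkeeping easier to follow; a sketch (the extractor itself additionally relies on BivariateFeature.get_pair_iterators to enumerate all pairs):

import numpy as np
from scipy.signal import csd, welch

fs, n = 128, 512
rng = np.random.default_rng(0)
x, y = rng.standard_normal((2, n))

f, sxy = csd(x, y, fs=fs, nperseg=128)
_, sxx = welch(x, fs=fs, nperseg=128)
_, syy = welch(y, fs=fs, nperseg=128)
c = sxy / np.sqrt(sxx * syy)   # complex coherency
msc = c.real**2 + c.imag**2    # magnitude-squared coherence
icoh = c.imag                  # imaginary coherence
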
@@ -0,0 +1,101 @@
+ import numba as nb
+ import numpy as np
+ import scipy
+ import scipy.linalg
+
+ from ..decorators import multivariate_feature
+ from ..extractors import FitableFeature
+
+ __all__ = [
+     "CommonSpatialPattern",
+ ]
+
+
+ @nb.njit(cache=True, fastmath=True, parallel=True)
+ def _update_mean_cov(count, mean, cov, x_count, x_mean, x_cov):
+     # Online update of a class mean/covariance with a new batch of samples.
+     alpha2 = x_count / count
+     alpha1 = 1 - alpha2
+     cov[:] = alpha1 * (cov + np.outer(mean, mean))
+     cov[:] += alpha2 * (x_cov + np.outer(x_mean, x_mean))
+     mean[:] = alpha1 * mean + alpha2 * x_mean
+     cov[:] -= np.outer(mean, mean)
+
+
+ @multivariate_feature
+ class CommonSpatialPattern(FitableFeature):
+     def __init__(self):
+         super().__init__()
+
+     def clear(self):
+         self._labels = None
+         self._counts = np.array([0, 0])
+         self._means = np.array([None, None])
+         self._covs = np.array([None, None])
+         self._mean = None
+         self._eigvals = None
+         self._weights = None
+
+     def _update_labels(self, labels):
+         if self._labels is None:
+             self._labels = labels
+         else:
+             for label in labels:
+                 if label not in self._labels:
+                     self._labels = np.append(self._labels, label)
+         assert self._labels.shape[0] < 3  # only binary CSP is supported
+         return self._labels
+
+     def _update_stats(self, l, x):
+         x_count, x_mean, x_cov = x.shape[0], x.mean(axis=0), np.cov(x.T, ddof=0)
+         if self._counts[l] == 0:
+             self._counts[l] = x_count
+             self._means[l] = x_mean
+             self._covs[l] = x_cov
+         else:
+             self._counts[l] += x_count
+             _update_mean_cov(
+                 self._counts[l], self._means[l], self._covs[l], x_count, x_mean, x_cov
+             )
+
+     def partial_fit(self, x, y=None):
+         labels = self._update_labels(np.unique(y))
+         for i, l in enumerate(labels):
+             ind = (y == l).nonzero()[0]
+             if ind.shape[0] > 0:
+                 xl = self.transform_input(x[ind])
+                 self._update_stats(i, xl)
+
+     @staticmethod
+     def transform_input(x):
+         # (epochs, channels, times) -> (epochs * times, channels)
+         return x.swapaxes(1, 2).reshape(-1, x.shape[1])
+
+     def fit(self):
+         alphas = self._counts / self._counts.sum()
+         self._mean = np.sum(alphas * self._means)
+         for l in range(len(self._labels)):
+             # Unbias each per-class covariance estimate.
+             self._covs[l] *= self._counts[l] / (self._counts[l] - 1)
+         l, w = scipy.linalg.eig(self._covs[0], self._covs[0] + self._covs[1])
+         l = l.real
+         ind = l > 0
+         l, w = l[ind], w[:, ind]
+         # Most discriminative filters first: eigenvalues farthest from 0.5.
+         ord = np.abs(l - 0.5).argsort()[::-1]
+         self._eigvals = l[ord]
+         self._weights = w[:, ord]
+         super().fit()
+
+     def __call__(self, x, n_select=None, crit_select=None):
+         super().__call__()
+         w = self._weights
+         if n_select:
+             w = w[:, :n_select]
+         if crit_select:
+             sel = 0.5 - np.abs(self._eigvals - 0.5) < crit_select
+             w = w[:, sel]
+         if w.shape[-1] == 0:
+             raise RuntimeError(
+                 "CSP weights selection criterion is too strict, "
+                 "all weights were filtered out."
+             )
+         proj = (self.transform_input(x) - self._mean) @ w
+         proj = proj.reshape(x.shape[0], x.shape[2], -1).mean(axis=1)
+         return {f"{i}": proj[:, i] for i in range(proj.shape[-1])}
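
A usage sketch for the class, assuming the FitableFeature lifecycle implied by the method bodies above (clear, then one or more partial_fit calls, then fit, then feature extraction via __call__) and epochs of shape (epochs, channels, times):

import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal((40, 8, 128))  # 40 epochs, 8 channels, 128 samples
y = np.repeat([0, 1], 20)              # binary labels, as the assert requires

csp = CommonSpatialPattern()
csp.clear()
for start in range(0, 40, 10):         # statistics accumulate over mini-batches
    csp.partial_fit(x[start : start + 10], y[start : start + 10])
csp.fit()                              # solve the generalised eigenproblem
features = csp(x, n_select=4)          # dict of the 4 most discriminative projections
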
@@ -0,0 +1,107 @@
+ import numba as nb
+ import numpy as np
+ from scipy import special
+
+ from ..decorators import univariate_feature
+ from .signal import signal_zero_crossings
+
+ __all__ = [
+     "dimensionality_higuchi_fractal_dim",
+     "dimensionality_petrosian_fractal_dim",
+     "dimensionality_katz_fractal_dim",
+     "dimensionality_hurst_exp",
+     "dimensionality_detrended_fluctuation_analysis",
+ ]
+
+
+ @univariate_feature
+ @nb.njit(cache=True, fastmath=True)
+ def dimensionality_higuchi_fractal_dim(x, k_max=10, eps=1e-7):
+     N = x.shape[-1]
+     hfd = np.empty(x.shape[:-1])
+     log_k = np.vstack((-np.log(np.arange(1, k_max + 1)), np.ones(k_max))).T
+     L_km = np.empty(k_max)
+     L_k = np.empty(k_max)
+     for i in np.ndindex(x.shape[:-1]):
+         for k in range(1, k_max + 1):
+             for m in range(k):
+                 # Mean absolute increment of the lag-k subseries starting at m.
+                 L_km[m] = np.mean(np.abs(np.diff(x[i][m::k])))
+             L_k[k - 1] = (N - 1) * np.sum(L_km[:k]) / (k**3)
+         L_k = np.maximum(L_k, eps)
+         hfd[i] = np.linalg.lstsq(log_k, np.log(L_k))[0][0]
+     return hfd
+
+
+ @univariate_feature
+ def dimensionality_petrosian_fractal_dim(x):
+     nd = signal_zero_crossings(np.diff(x, axis=-1))
+     log_n = np.log(x.shape[-1])
+     return log_n / (np.log(nd) + log_n)
+
+
+ @univariate_feature
+ def dimensionality_katz_fractal_dim(x):
+     dists = np.abs(np.diff(x, axis=-1))
+     L = dists.sum(axis=-1)
+     a = dists.mean(axis=-1)
+     log_n = np.log(L / a)  # L / a equals the number of steps n
+     d = np.abs(x[..., 1:] - x[..., 0, None]).max(axis=-1)
+     return log_n / (np.log(d / L) + log_n)
+
+
+ # numba helper performing the per-channel rescaled-range (R/S) computation
+ @nb.njit(cache=True, fastmath=True)
+ def _hurst_exp(x, ns, a, gamma_ratios, log_n):
+     h = np.empty(x.shape[:-1])
+     rs = np.empty((ns.shape[0], x.shape[-1] // ns[0]))
+     log_rs = np.empty(ns.shape[0])
+     for i in np.ndindex(x.shape[:-1]):
+         for j, n in enumerate(ns):
+             # Full windows of length n only.
+             for k, t0 in enumerate(range(0, x.shape[-1] - n + 1, n)):
+                 xj = x[i][t0 : t0 + n]
+                 m = np.mean(xj)
+                 y = xj - m
+                 z = np.cumsum(y)
+                 r = np.ptp(z)
+                 s = np.sqrt(np.mean(y**2))
+                 if s == 0.0:
+                     rs[j, k] = np.nan
+                 else:
+                     rs[j, k] = r / s
+             log_rs[j] = np.log(np.nanmean(rs[j, : x.shape[-1] // n]))
+             # Anis-Lloyd finite-sample correction to E[R/S].
+             log_rs[j] -= np.log(np.sum(np.sqrt((n - a[:n]) / a[:n])) * gamma_ratios[j])
+         h[i] = 0.5 + np.linalg.lstsq(log_n, log_rs)[0][0]
+     return h
+
+
+ @univariate_feature
+ def dimensionality_hurst_exp(x):
+     ns = np.unique(np.power(2, np.arange(2, np.log2(x.shape[-1]) - 1)).astype(int))
+     # Gamma ratio of the Anis-Lloyd correction; the asymptotic form is used
+     # for large n to avoid overflow in special.gamma.
+     idx = ns > 340
+     gamma_ratios = np.empty(ns.shape[0])
+     gamma_ratios[idx] = 1 / np.sqrt(ns[idx] / 2)
+     gamma_ratios[~idx] = special.gamma((ns[~idx] - 1) / 2) / special.gamma(ns[~idx] / 2)
+     gamma_ratios /= np.sqrt(np.pi)
+     log_n = np.vstack((np.log(ns), np.ones(ns.shape[0]))).T
+     a = np.arange(1, ns[-1], dtype=float)
+     return _hurst_exp(x, ns, a, gamma_ratios, log_n)
+
+
+ @univariate_feature
+ @nb.njit(cache=True, fastmath=True)
+ def dimensionality_detrended_fluctuation_analysis(x):
+     ns = np.unique(np.floor(np.power(2, np.arange(2, np.log2(x.shape[-1]) - 1))))
+     a = np.vstack((np.arange(ns[-1]), np.ones(int(ns[-1])))).T  # linear-trend design matrix
+     log_n = np.vstack((np.log(ns), np.ones(ns.shape[0]))).T
+     Fn = np.empty(ns.shape[0])
+     alpha = np.empty(x.shape[:-1])
+     for i in np.ndindex(x.shape[:-1]):
+         X = np.cumsum(x[i] - np.mean(x[i]))  # integrated, mean-centred profile
+         for j, n in enumerate(ns):
+             n = int(n)
+             # One contiguous window of length n per column; the lstsq residuals
+             # are the squared fluctuations around each window's linear trend.
+             Z = np.reshape(X[: n * (X.shape[0] // n)], (X.shape[0] // n, n)).T
+             Fni2 = np.linalg.lstsq(a[:n], Z)[1] / n
+             Fn[j] = np.sqrt(np.mean(Fni2))
+         alpha[i] = np.linalg.lstsq(log_n, np.log(Fn))[0][0]
+     return alpha
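
For orientation, a tiny standalone Higuchi estimator in plain NumPy (same lag-k increments and log-log fit as the numba version above); for white noise the estimate should come out close to 2:

import numpy as np

def higuchi_fd(x, k_max=10):
    N = x.shape[-1]
    L_k = np.empty(k_max)
    for k in range(1, k_max + 1):
        L_m = [np.abs(np.diff(x[m::k])).mean() for m in range(k)]
        L_k[k - 1] = (N - 1) * np.sum(L_m) / k**3
    A = np.vstack((-np.log(np.arange(1, k_max + 1)), np.ones(k_max))).T
    return np.linalg.lstsq(A, np.log(L_k), rcond=None)[0][0]

print(higuchi_fd(np.random.randn(1024)))  # approximately 2 for white noise
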
@@ -0,0 +1,103 @@
+ import numbers
+
+ import numpy as np
+ from scipy import stats
+
+ from ..decorators import univariate_feature
+
+ __all__ = [
+     "signal_mean",
+     "signal_variance",
+     "signal_skewness",
+     "signal_kurtosis",
+     "signal_std",
+     "signal_root_mean_square",
+     "signal_peak_to_peak",
+     "signal_quantile",
+     "signal_zero_crossings",
+     "signal_line_length",
+     "signal_hjorth_activity",
+     "signal_hjorth_mobility",
+     "signal_hjorth_complexity",
+     "signal_decorrelation_time",
+ ]
+
+
+ @univariate_feature
+ def signal_mean(x):
+     return x.mean(axis=-1)
+
+
+ @univariate_feature
+ def signal_variance(x, **kwargs):
+     return x.var(axis=-1, **kwargs)
+
+
+ @univariate_feature
+ def signal_std(x, **kwargs):
+     return x.std(axis=-1, **kwargs)
+
+
+ @univariate_feature
+ def signal_skewness(x, **kwargs):
+     return stats.skew(x, axis=x.ndim - 1, **kwargs)
+
+
+ @univariate_feature
+ def signal_kurtosis(x, **kwargs):
+     return stats.kurtosis(x, axis=x.ndim - 1, **kwargs)
+
+
+ @univariate_feature
+ def signal_root_mean_square(x):
+     return np.sqrt(np.power(x, 2).mean(axis=-1))
+
+
+ @univariate_feature
+ def signal_peak_to_peak(x, **kwargs):
+     return np.ptp(x, axis=-1, **kwargs)
+
+
+ @univariate_feature
+ def signal_quantile(x, q: numbers.Number = 0.5, **kwargs):
+     return np.quantile(x, q=q, axis=-1, **kwargs)
+
+
+ @univariate_feature
+ def signal_line_length(x):
+     return np.abs(np.diff(x, axis=-1)).mean(axis=-1)
+
+
+ @univariate_feature
+ def signal_zero_crossings(x, threshold=1e-15):
+     # Samples within +/- threshold are snapped to zero before counting sign changes.
+     zero_ind = np.logical_and(x > -threshold, x < threshold)
+     zero_cross = np.diff(zero_ind, axis=-1).astype(int).sum(axis=-1)
+     y = x.copy()
+     y[zero_ind] = 0
+     zero_cross += np.sum(np.signbit(y[..., :-1]) != np.signbit(y[..., 1:]), axis=-1)
+     return zero_cross
+
+
+ @univariate_feature
+ def signal_hjorth_mobility(x):
+     return np.diff(x, axis=-1).std(axis=-1) / x.std(axis=-1)
+
+
+ @univariate_feature
+ def signal_hjorth_complexity(x):
+     return np.diff(x, 2, axis=-1).std(axis=-1) / x.std(axis=-1)
+
+
+ @univariate_feature
+ def signal_decorrelation_time(x, fs=1):
+     # Autocorrelation via the Wiener-Khinchin theorem (inverse FFT of the power spectrum).
+     f = np.fft.fft(x - x.mean(axis=-1, keepdims=True), axis=-1)
+     ac = np.fft.ifft(f.real**2 + f.imag**2, axis=-1).real[..., : x.shape[-1] // 2]
+     dct = np.empty(x.shape[:-1])
+     for i in np.ndindex(x.shape[:-1]):
+         # First lag at which the autocorrelation is no longer positive.
+         dct[i] = np.argmax(ac[i] <= 0)
+     return dct / fs
+
+
+ # ================================= Aliases =================================
+
+ signal_hjorth_activity = signal_variance
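
As a sanity check on signal_decorrelation_time: the autocorrelation of a pure sinusoid first becomes non-positive after a quarter period, so the feature should return 1/(4*f0). A standalone sketch of the same Wiener-Khinchin computation:

import numpy as np

fs, f0 = 256, 8.0
t = np.arange(fs) / fs
x = np.sin(2 * np.pi * f0 * t)

spec = np.fft.fft(x - x.mean())
ac = np.fft.ifft(spec.real**2 + spec.imag**2).real[: x.size // 2]
lag = np.argmax(ac <= 0)           # first non-positive autocorrelation lag
print(lag / fs, 1 / (4 * f0))      # both approximately 0.031 s
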
@@ -0,0 +1,116 @@
+ import numba as nb
+ import numpy as np
+ from scipy.signal import welch
+
+ from ..decorators import FeaturePredecessor, univariate_feature
+ from ..extractors import FeatureExtractor
+ from . import utils
+
+ __all__ = [
+     "SpectralFeatureExtractor",
+     "NormalizedSpectralFeatureExtractor",
+     "DBSpectralFeatureExtractor",
+     "spectral_root_total_power",
+     "spectral_moment",
+     "spectral_entropy",
+     "spectral_edge",
+     "spectral_slope",
+     "spectral_bands_power",
+     "spectral_hjorth_activity",
+     "spectral_hjorth_mobility",
+     "spectral_hjorth_complexity",
+ ]
+
+
+ class SpectralFeatureExtractor(FeatureExtractor):
+     def preprocess(self, x, **kwargs):
+         f_min = kwargs.pop("f_min", None)
+         f_max = kwargs.pop("f_max", None)
+         assert "fs" in kwargs
+         kwargs["axis"] = -1
+         f, p = welch(x, **kwargs)
+         f_min, f_max = utils.get_valid_freq_band(
+             kwargs["fs"], x.shape[-1], f_min, f_max
+         )
+         f, p = utils.slice_freq_band(f, p, f_min=f_min, f_max=f_max)
+         return f, p
+
+
+ @FeaturePredecessor(SpectralFeatureExtractor)
+ class NormalizedSpectralFeatureExtractor(FeatureExtractor):
+     def preprocess(self, *x):
+         # Normalise the PSD so it can be treated as a probability distribution.
+         return (*x[:-1], x[-1] / x[-1].sum(axis=-1, keepdims=True))
+
+
+ @FeaturePredecessor(SpectralFeatureExtractor)
+ class DBSpectralFeatureExtractor(FeatureExtractor):
+     def preprocess(self, *x, eps=1e-15):
+         return (*x[:-1], 10 * np.log10(x[-1] + eps))
+
+
+ @FeaturePredecessor(SpectralFeatureExtractor)
+ @univariate_feature
+ def spectral_root_total_power(f, p):
+     return np.sqrt(p.sum(axis=-1))
+
+
+ @FeaturePredecessor(NormalizedSpectralFeatureExtractor)
+ @univariate_feature
+ def spectral_moment(f, p):
+     return np.sum(f * p, axis=-1)
+
+
+ @FeaturePredecessor(SpectralFeatureExtractor)
+ @univariate_feature
+ def spectral_hjorth_activity(f, p):
+     return np.sum(p, axis=-1)
+
+
+ @FeaturePredecessor(NormalizedSpectralFeatureExtractor)
+ @univariate_feature
+ def spectral_hjorth_mobility(f, p):
+     return np.sqrt(np.sum(np.power(f, 2) * p, axis=-1))
+
+
+ @FeaturePredecessor(NormalizedSpectralFeatureExtractor)
+ @univariate_feature
+ def spectral_hjorth_complexity(f, p):
+     return np.sqrt(np.sum(np.power(f, 4) * p, axis=-1))
+
+
+ @FeaturePredecessor(NormalizedSpectralFeatureExtractor)
+ @univariate_feature
+ def spectral_entropy(f, p):
+     idx = p > 0
+     plogp = np.zeros_like(p)
+     plogp[idx] = p[idx] * np.log(p[idx])
+     return -np.sum(plogp, axis=-1)
+
+
+ @FeaturePredecessor(NormalizedSpectralFeatureExtractor)
+ @univariate_feature
+ @nb.njit(cache=True, fastmath=True)
+ def spectral_edge(f, p, edge=0.9):
+     # Frequency below which the fraction `edge` of the total power lies.
+     se = np.empty(p.shape[:-1])
+     for i in np.ndindex(p.shape[:-1]):
+         se[i] = f[np.searchsorted(np.cumsum(p[i]), edge)]
+     return se
+
+
+ @FeaturePredecessor(DBSpectralFeatureExtractor)
+ @univariate_feature
+ def spectral_slope(f, p):
+     # Linear fit of the dB power spectrum against log-frequency.
+     log_f = np.vstack((np.log(f), np.ones(f.shape[0]))).T
+     r = np.linalg.lstsq(log_f, p.reshape(-1, p.shape[-1]).T)[0]
+     r = r.reshape(2, *p.shape[:-1])
+     return {"exp": r[0], "int": r[1]}
+
+
+ @FeaturePredecessor(
+     SpectralFeatureExtractor,
+     NormalizedSpectralFeatureExtractor,
+     DBSpectralFeatureExtractor,
+ )
+ @univariate_feature
+ def spectral_bands_power(f, p, bands=utils.DEFAULT_FREQ_BANDS):
+     return utils.reduce_freq_bands(f, p, bands, np.sum)
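
At its core the pipeline is a Welch PSD restricted to a valid band, followed by per-band reductions. A direct sketch of what spectral_bands_power computes over the default bands, without the extractor classes:

import numpy as np
from scipy.signal import welch

fs = 128
x = np.random.randn(4, 4 * fs)   # (channels, times)
f, p = welch(x, fs=fs, nperseg=fs, axis=-1)

bands = {"delta": (1, 4.5), "theta": (4.5, 8), "alpha": (8, 12), "beta": (12, 30)}
power = {
    name: p[..., (f >= lo) & (f < hi)].sum(axis=-1)
    for name, (lo, hi) in bands.items()
}
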
@@ -0,0 +1,48 @@
+ import numpy as np
+
+ DEFAULT_FREQ_BANDS = {
+     "delta": (1, 4.5),
+     "theta": (4.5, 8),
+     "alpha": (8, 12),
+     "beta": (12, 30),
+ }
+
+
+ def get_valid_freq_band(fs, n, f_min=None, f_max=None):
+     f0 = 2 * fs / n  # lowest frequency with at least two full cycles in the window
+     f1 = fs / 2  # Nyquist frequency
+     if f_min is None:
+         f_min = f0
+     else:
+         assert f_min >= f0
+     if f_max is None:
+         f_max = f1
+     else:
+         assert f_max <= f1
+     return f_min, f_max
+
+
+ def slice_freq_band(f, *x, f_min=None, f_max=None):
+     if f_min is None and f_max is None:
+         return f, *x
+     else:
+         f_min_idx = f >= f_min if f_min is not None else True
+         f_max_idx = f <= f_max if f_max is not None else True
+         idx = np.logical_and(f_min_idx, f_max_idx)
+         f = f[idx]
+         xl = [*x]
+         for i, xi in enumerate(xl):
+             xl[i] = xi[..., idx]
+         return f, *xl
+
+
+ def reduce_freq_bands(f, x, bands, reduce_func=np.sum):
+     x_bands = dict()
+     for k, lims in bands.items():
+         assert isinstance(k, str)
+         assert len(lims) == 2 and lims[0] <= lims[1]
+         assert lims[0] >= f[0] and lims[1] <= f[-1]
+         # Half-open bins [lo, hi) so adjacent bands do not double-count.
+         mask = np.logical_and(f >= lims[0], f < lims[1])
+         xf = x[..., mask]
+         x_bands[k] = reduce_func(xf, axis=-1)
+     return x_bands
+ return x_bands