eegdash 0.0.9__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- eegdash/__init__.py +8 -1
- eegdash/api.py +690 -0
- eegdash/data_config.py +33 -27
- eegdash/data_utils.py +365 -222
- eegdash/dataset.py +60 -0
- eegdash/features/__init__.py +46 -18
- eegdash/features/datasets.py +62 -23
- eegdash/features/decorators.py +14 -6
- eegdash/features/extractors.py +22 -22
- eegdash/features/feature_bank/__init__.py +3 -3
- eegdash/features/feature_bank/complexity.py +6 -3
- eegdash/features/feature_bank/connectivity.py +16 -56
- eegdash/features/feature_bank/csp.py +3 -4
- eegdash/features/feature_bank/dimensionality.py +8 -5
- eegdash/features/feature_bank/signal.py +30 -4
- eegdash/features/feature_bank/spectral.py +10 -28
- eegdash/features/feature_bank/utils.py +48 -0
- eegdash/features/inspect.py +48 -0
- eegdash/features/serialization.py +4 -5
- eegdash/features/utils.py +9 -7
- eegdash/preprocessing.py +65 -0
- eegdash/utils.py +11 -0
- {eegdash-0.0.9.dist-info → eegdash-0.2.0.dist-info}/METADATA +67 -20
- eegdash-0.2.0.dist-info/RECORD +27 -0
- {eegdash-0.0.9.dist-info → eegdash-0.2.0.dist-info}/WHEEL +1 -1
- {eegdash-0.0.9.dist-info → eegdash-0.2.0.dist-info}/licenses/LICENSE +1 -0
- eegdash/main.py +0 -359
- eegdash-0.0.9.dist-info/RECORD +0 -22
- {eegdash-0.0.9.dist-info → eegdash-0.2.0.dist-info}/top_level.txt +0 -0
eegdash/dataset.py
ADDED
@@ -0,0 +1,60 @@
+from .api import EEGDashDataset
+
+
+class EEGChallengeDataset(EEGDashDataset):
+    def __init__(
+        self,
+        release: str = "R5",
+        cache_dir: str = ".eegdash_cache",
+        s3_bucket: str | None = "s3://nmdatasets/NeurIPS25/R5_L100",
+        **kwargs,
+    ):
+        """Create a new EEGDashDataset from a given query or local BIDS dataset directory
+        and dataset name. An EEGDashDataset is a pooled collection of EEGDashBaseDataset
+        instances (individual recordings) and is a subclass of braindecode's BaseConcatDataset.
+
+        Parameters
+        ----------
+        query : dict | None
+            Optionally a dictionary that specifies the query to be executed; see
+            EEGDash.find() for details on the query format.
+        data_dir : str | list[str] | None
+            Optionally a string or a list of strings specifying one or more local
+            BIDS dataset directories from which to load the EEG data files. Exactly one
+            of query or data_dir must be provided.
+        dataset : str | list[str] | None
+            If data_dir is given, a name or list of names for the dataset(s) to be loaded.
+        description_fields : list[str]
+            A list of fields to be extracted from the dataset records
+            and included in the returned data description(s). Examples are typical
+            subject metadata fields such as "subject", "session", "run", "task", etc.;
+            see also data_config.description_fields for the default set of fields.
+        cache_dir : str
+            A directory where the dataset will be cached locally.
+        s3_bucket : str | None
+            An optional S3 bucket URI to use instead of the
+            default OpenNeuro bucket for loading data files.
+        kwargs : dict
+            Additional keyword arguments to be passed to the EEGDashBaseDataset
+            constructor.
+
+        """
+        dsnumber_release_map = {
+            "R11": "ds005516",
+            "R10": "ds005515",
+            "R9": "ds005514",
+            "R8": "ds005512",
+            "R7": "ds005511",
+            "R6": "ds005510",
+            "R4": "ds005508",
+            "R5": "ds005509",
+            "R3": "ds005507",
+            "R2": "ds005506",
+            "R1": "ds005505",
+        }
+        super().__init__(
+            query={"dataset": dsnumber_release_map[release]},
+            cache_dir=cache_dir,
+            s3_bucket=s3_bucket,
+            **kwargs,
+        )
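A minimal usage sketch of the new class (the release-to-dataset mapping is the one in the code above; note that the default s3_bucket points at the R5 bucket, so selecting another release presumably requires passing a matching bucket explicitly):

    from eegdash.dataset import EEGChallengeDataset

    # Release "R5" resolves to OpenNeuro dataset ds005509 via
    # dsnumber_release_map; recordings are cached under cache_dir.
    ds = EEGChallengeDataset(release="R5", cache_dir=".eegdash_cache")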
eegdash/features/__init__.py
CHANGED
@@ -1,25 +1,53 @@
-
-from .
-
-
-
+from .datasets import FeaturesConcatDataset, FeaturesDataset
+from .decorators import (
+    FeatureKind,
+    FeaturePredecessor,
+    bivariate_feature,
+    multivariate_feature,
+    univariate_feature,
+)
 from .extractors import (
-    FeatureExtractor,
-    FitableFeature,
-    UnivariateFeature,
     BivariateFeature,
     DirectedBivariateFeature,
+    FeatureExtractor,
     MultivariateFeature,
+    TrainableFeature,
+    UnivariateFeature,
 )
-from .
-
-
-
-
-
+from .feature_bank import *  # noqa: F401
+from .inspect import (
+    get_all_feature_extractors,
+    get_all_feature_kinds,
+    get_all_features,
+    get_feature_kind,
+    get_feature_predecessors,
+)
+from .serialization import load_features_concat_dataset
+from .utils import (
+    extract_features,
+    fit_feature_extractors,
 )
-from .utils import extract_features, fit_feature_extractors
 
-
-
+__all__ = [
+    "FeaturesConcatDataset",
+    "FeaturesDataset",
+    "FeatureKind",
+    "FeaturePredecessor",
+    "bivariate_feature",
+    "multivariate_feature",
+    "univariate_feature",
+    "BivariateFeature",
+    "DirectedBivariateFeature",
+    "FeatureExtractor",
+    "MultivariateFeature",
+    "TrainableFeature",
+    "UnivariateFeature",
+    "get_all_feature_extractors",
+    "get_all_feature_kinds",
+    "get_all_features",
+    "get_feature_kind",
+    "get_feature_predecessors",
+    "load_features_concat_dataset",
+    "extract_features",
+    "fit_feature_extractors",
+]
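The module now pins its public surface with __all__, so everything above can be imported from eegdash.features directly. A quick sanity sketch using only names from that list:

    from eegdash.features import (
        FeatureExtractor,
        TrainableFeature,
        extract_features,
        fit_feature_extractors,
        get_all_features,
    )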
eegdash/features/datasets.py
CHANGED
@@ -1,16 +1,19 @@
 from __future__ import annotations
-
+
 import json
+import os
 import shutil
 import warnings
-from
-from
+from collections.abc import Callable
+from typing import Dict, List
+
 import numpy as np
 import pandas as pd
 from joblib import Parallel, delayed
+
 from braindecode.datasets.base import (
-    EEGWindowsDataset,
     BaseConcatDataset,
+    EEGWindowsDataset,
     _create_description,
 )
@@ -30,6 +33,7 @@ class FeaturesDataset(EEGWindowsDataset):
         Holds additional description about the continuous signal / subject.
     transform : callable | None
         On-the-fly transform applied to the example before it is returned.
+
     """

     def __init__(
@@ -92,10 +96,12 @@ def _compute_stats(
     return tuple(res)


-def _pooled_var(counts, means, variances, ddof):
+def _pooled_var(counts, means, variances, ddof, ddof_in=None):
+    if ddof_in is None:
+        ddof_in = ddof
     count = counts.sum(axis=0)
     mean = np.sum((counts / count) * means, axis=0)
-    var = np.sum(((counts -
+    var = np.sum(((counts - ddof_in) / (count - ddof)) * variances, axis=0)
     var[:] += np.sum((counts / (count - ddof)) * (means**2), axis=0)
     var[:] -= (count / (count - ddof)) * (mean**2)
     var[:] = var.clip(min=0)
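The reworked _pooled_var combines per-dataset running statistics into one global variance: ddof_in is the correction the per-dataset variances were computed with, while ddof applies to the pooled result. A self-contained check of that identity (a sketch; in the real code these are arrays with one entry per feature column):

    import numpy as np

    rng = np.random.default_rng(0)
    parts = [rng.normal(size=n) for n in (50, 80, 120)]
    counts = np.array([len(p) for p in parts], dtype=float)
    means = np.array([p.mean() for p in parts])
    variances = np.array([p.var(ddof=0) for p in parts])  # ddof_in = 0

    # Mirror the arithmetic in _pooled_var with ddof=1, ddof_in=0:
    count = counts.sum()
    mean = np.sum((counts / count) * means)
    var = np.sum(((counts - 0) / (count - 1)) * variances)
    var += np.sum((counts / (count - 1)) * means**2)
    var -= (count / (count - 1)) * mean**2

    assert np.isclose(var, np.concatenate(parts).var(ddof=1))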
@@ -153,6 +159,7 @@ class FeaturesConcatDataset(BaseConcatDataset):
         splits : dict
             A dictionary with the name of the split (a string) as key and the
             dataset as value.
+
         """
         if isinstance(by, str):
             split_ids = {
@@ -184,6 +191,7 @@ class FeaturesConcatDataset(BaseConcatDataset):
             DataFrame containing as many rows as there are windows in the
             BaseConcatDataset, with the metadata and description information
             for each window.
+
         """
         if not all([isinstance(ds, FeaturesDataset) for ds in self.datasets]):
             raise TypeError(
@@ -235,6 +243,7 @@ class FeaturesConcatDataset(BaseConcatDataset):
         concat. This is useful in the setting of very large datasets, where
         one dataset has to be processed and saved at a time to account for
         its original position.
+
         """
         if len(self.datasets) == 0:
             raise ValueError("Expect at least one dataset")
@@ -320,25 +329,53 @@ class FeaturesConcatDataset(BaseConcatDataset):
             json.dump(kwargs, f)

     def to_dataframe(
-        self,
+        self,
+        include_metadata: bool | str | List[str] = False,
+        include_target: bool = False,
+        include_crop_inds: bool = False,
     ):
-        if
+        if (
+            not isinstance(include_metadata, bool)
+            or include_metadata
+            or include_crop_inds
+        ):
+            include_dataset = False
+            if isinstance(include_metadata, bool) and include_metadata:
+                include_dataset = True
+                cols = self.datasets[0].metadata.columns
+            else:
+                cols = include_metadata
+                if isinstance(cols, bool) and not cols:
+                    cols = []
+                elif isinstance(cols, str):
+                    cols = [cols]
+            cols = set(cols)
+            if include_crop_inds:
+                cols = {
+                    "i_dataset",
+                    "i_window_in_trial",
+                    "i_start_in_trial",
+                    "i_stop_in_trial",
+                    *cols,
+                }
+            if include_target:
+                cols.add("target")
+            cols = list(cols)
+            include_dataset = "i_dataset" in cols
+            if include_dataset:
+                cols.remove("i_dataset")
             dataframes = [
-                ds.metadata.join(ds.features, how="right", lsuffix="_metadata")
+                ds.metadata[cols].join(ds.features, how="right", lsuffix="_metadata")
                 for ds in self.datasets
             ]
+            if include_dataset:
+                for i, df in enumerate(dataframes):
+                    df.insert(loc=0, column="i_dataset", value=i)
         elif include_target:
             dataframes = [
                 ds.features.join(ds.metadata["target"], how="left", rsuffix="_metadata")
                 for ds in self.datasets
             ]
-        elif include_crop_inds:
-            dataframes = [
-                ds.metadata.drop("target", axis="columns").join(
-                    ds.features, how="right", lsuffix="_metadata"
-                )
-                for ds in self.datasets
-            ]
         else:
             dataframes = [ds.features for ds in self.datasets]
         return pd.concat(dataframes, axis=0, ignore_index=True)
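A hypothetical call pattern for the extended method (concat_ds stands in for any FeaturesConcatDataset instance; the flag names come from the signature above):

    # One row per window; crop indices pull i_dataset, i_window_in_trial,
    # i_start_in_trial and i_stop_in_trial into the frame next to the features.
    df = concat_ds.to_dataframe(include_crop_inds=True, include_target=True)

    # Or select specific metadata columns by name:
    df = concat_ds.to_dataframe(include_metadata=["subject", "session"])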
@@ -374,7 +411,7 @@ class FeaturesConcatDataset(BaseConcatDataset):
                 return_count=True,
                 return_mean=True,
                 return_var=True,
-                ddof=
+                ddof=0,
                 numeric_only=numeric_only,
             )
             for ds in self.datasets
@@ -384,11 +421,13 @@ class FeaturesConcatDataset(BaseConcatDataset):
             np.array([s[1] for s in stats]),
             np.array([s[2] for s in stats]),
         )
-        _, _, var = _pooled_var(counts, means, variances, ddof)
+        _, _, var = _pooled_var(counts, means, variances, ddof, ddof_in=0)
         return pd.Series(var, index=self._numeric_columns())

-    def std(self, ddof=1, numeric_only=False, n_jobs=1):
-        return np.sqrt(
+    def std(self, ddof=1, numeric_only=False, eps=0, n_jobs=1):
+        return np.sqrt(
+            self.var(ddof=ddof, numeric_only=numeric_only, n_jobs=n_jobs) + eps
+        )

     def zscore(self, ddof=1, numeric_only=False, eps=0, n_jobs=1):
         stats = Parallel(n_jobs)(
@@ -397,7 +436,7 @@ class FeaturesConcatDataset(BaseConcatDataset):
                 return_count=True,
                 return_mean=True,
                 return_var=True,
-                ddof=
+                ddof=0,
                 numeric_only=numeric_only,
             )
             for ds in self.datasets
@@ -407,8 +446,8 @@ class FeaturesConcatDataset(BaseConcatDataset):
             np.array([s[1] for s in stats]),
             np.array([s[2] for s in stats]),
         )
-        _, mean, var = _pooled_var(counts, means, variances, ddof)
-        std = np.sqrt(var
+        _, mean, var = _pooled_var(counts, means, variances, ddof, ddof_in=0)
+        std = np.sqrt(var + eps)
         for ds in self.datasets:
             ds.features = (ds.features - mean) / std

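The new eps argument guards against division by zero when normalizing features whose pooled variance is (numerically) zero; note that zscore rewrites every dataset's feature table in place. A hedged usage sketch with the same assumed concat_ds as above:

    # Pooled per-feature standard deviation across all datasets:
    sigma = concat_ds.std(ddof=1, eps=1e-12)

    # In-place z-scoring using the pooled mean/std over all datasets:
    concat_ds.zscore(ddof=1, eps=1e-12)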
eegdash/features/decorators.py
CHANGED
@@ -1,14 +1,14 @@
-from typing import List, Type
 from collections.abc import Callable
+from typing import List, Type

 from .extractors import (
-    FeatureExtractor,
-    UnivariateFeature,
     BivariateFeature,
     DirectedBivariateFeature,
+    FeatureExtractor,
     MultivariateFeature,
+    UnivariateFeature,
+    _get_underlying_func,
 )
-from .extractors import _get_underlying_func


 class FeaturePredecessor:
@@ -38,6 +38,14 @@ class FeatureKind:

 # Syntax sugar
 univariate_feature = FeatureKind(UnivariateFeature())
-
-
+
+
+def bivariate_feature(func, directed=False):
+    if directed:
+        kind = DirectedBivariateFeature()
+    else:
+        kind = BivariateFeature()
+    return FeatureKind(kind)(func)
+
+
 multivariate_feature = FeatureKind(MultivariateFeature())
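bivariate_feature is now a small factory around FeatureKind rather than a bare instance, so a directed kind can be requested per feature via directed=True. A hypothetical coherence-derived feature, mirroring the decorator stack used on the connectivity features below:

    import numpy as np

    from eegdash.features import FeaturePredecessor, bivariate_feature
    from eegdash.features.feature_bank.connectivity import CoherenceFeatureExtractor

    # Hypothetical example: mean coherency magnitude per channel pair.
    @FeaturePredecessor(CoherenceFeatureExtractor)
    @bivariate_feature
    def connectivity_mean_abs_coherency(f, c):
        return np.abs(c).mean(axis=-1)

    # A directed measure would instead be wrapped as
    # bivariate_feature(my_func, directed=True).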
eegdash/features/extractors.py
CHANGED
@@ -1,7 +1,8 @@
 from abc import ABC, abstractmethod
-from typing import Dict
 from collections.abc import Callable
 from functools import partial
+from typing import Dict
+
 import numpy as np
 from numba.core.dispatcher import Dispatcher

@@ -15,9 +16,9 @@ def _get_underlying_func(func):
     return f


-class FitableFeature(ABC):
+class TrainableFeature(ABC):
     def __init__(self):
-        self.
+        self._is_trained = False
         self.clear()

     @abstractmethod
@@ -38,12 +39,12 @@ class FitableFeature(ABC):
         )


-class FeatureExtractor(FitableFeature):
+class FeatureExtractor(TrainableFeature):
     def __init__(
         self, feature_extractors: Dict[str, Callable], **preprocess_kwargs: Dict
     ):
         self.feature_extractors_dict = self._validate_execution_tree(feature_extractors)
-        self.
+        self._is_trainable = self._check_is_trainable(feature_extractors)
         super().__init__()

         # bypassing FeaturePredecessor to avoid circular import
@@ -69,32 +70,31 @@ class FeatureExtractor(FitableFeature):
         assert type(self) in pe_type
         return feature_extractors

-    def
-
+    def _check_is_trainable(self, feature_extractors):
+        is_trainable = False
         for fname, f in feature_extractors.items():
             if isinstance(f, FeatureExtractor):
-
+                is_trainable = f._is_trainable
             else:
                 f = _get_underlying_func(f)
-                if isinstance(f,
-
-            if
+                if isinstance(f, TrainableFeature):
+                    is_trainable = True
+            if is_trainable:
                 break
-        return
+        return is_trainable

     def preprocess(self, *x, **kwargs):
         return (*x,)

-    def feature_channel_names(self, ch_names):
-        return [""]
-
     def __call__(self, *x, _batch_size=None, _ch_names=None):
         assert _batch_size is not None
         assert _ch_names is not None
-        if self.
+        if self._is_trainable:
             super().__call__()
         results_dict = dict()
         z = self.preprocess(*x, **self.preprocess_kwargs)
+        if not isinstance(z, tuple):
+            z = (z,)
         for fname, f in self.feature_extractors_dict.items():
             if isinstance(f, FeatureExtractor):
                 r = f(*z, _batch_size=_batch_size, _ch_names=_ch_names)
@@ -125,28 +125,28 @@ class FeatureExtractor(FitableFeature):
             results_dict[name] = value

     def clear(self):
-        if not self.
+        if not self._is_trainable:
             return
         for fname, f in self.feature_extractors_dict.items():
             f = _get_underlying_func(f)
-            if isinstance(f,
+            if isinstance(f, TrainableFeature):
                 f.clear()

     def partial_fit(self, *x, y=None):
-        if not self.
+        if not self._is_trainable:
             return
         z = self.preprocess(*x, **self.preprocess_kwargs)
         for fname, f in self.feature_extractors_dict.items():
             f = _get_underlying_func(f)
-            if isinstance(f,
+            if isinstance(f, TrainableFeature):
                 f.partial_fit(*z, y=y)

     def fit(self):
-        if not self.
+        if not self._is_trainable:
             return
         for fname, f in self.feature_extractors_dict.items():
             f = _get_underlying_func(f)
-            if isinstance(f,
+            if isinstance(f, TrainableFeature):
                 f.fit()
         super().fit()

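The FitableFeature to TrainableFeature rename comes with the clear()/partial_fit()/fit() lifecycle that FeatureExtractor forwards to nested features above. A minimal sketch of a custom trainable feature, under the assumption that those three methods are the required protocol (the full abstract method set is not visible in this diff):

    import numpy as np

    from eegdash.features import TrainableFeature

    class RunningMeanOffset(TrainableFeature):
        # Hypothetical feature: subtracts the mean observed during fitting.

        def clear(self):
            self._sum = 0.0
            self._n = 0

        def partial_fit(self, x, y=None):
            # Accumulate streaming statistics one batch at a time.
            self._sum += float(np.sum(x))
            self._n += x.size

        def fit(self):
            self._mean = self._sum / max(self._n, 1)
            super().fit()  # presumably marks the feature as trained

        def __call__(self, x):
            return x - self._mean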
eegdash/features/feature_bank/complexity.py
CHANGED
@@ -1,10 +1,10 @@
-import numpy as np
 import numba as nb
+import numpy as np
 from sklearn.neighbors import KDTree

-from ..extractors import FeatureExtractor
 from ..decorators import FeaturePredecessor, univariate_feature
-
+from ..extractors import FeatureExtractor
+from .signal import SIGNAL_PREDECESSORS

 __all__ = [
     "EntropyFeatureExtractor",
@@ -29,6 +29,7 @@ def _channel_app_samp_entropy_counts(x, m, r, l):
     return kdtree.query_radius(x_emb, r, count_only=True)


+@FeaturePredecessor(*SIGNAL_PREDECESSORS)
 class EntropyFeatureExtractor(FeatureExtractor):
     def preprocess(self, x, m=2, r=0.2, l=1):
         rr = r * x.std(axis=-1)
@@ -56,6 +57,7 @@ def complexity_sample_entropy(counts_m, counts_mp1):
     return -np.log(A / B)


+@FeaturePredecessor(*SIGNAL_PREDECESSORS)
 @univariate_feature
 def complexity_svd_entropy(x, m=10, tau=1):
     x_emb = np.empty((*x.shape[:-1], (x.shape[-1] - m + 1) // tau, m))
@@ -66,6 +68,7 @@ def complexity_svd_entropy(x, m=10, tau=1):
     return -np.sum(s * np.log(s), axis=-1)


+@FeaturePredecessor(*SIGNAL_PREDECESSORS)
 @univariate_feature
 @nb.njit(cache=True, fastmath=True)
 def complexity_lempel_ziv(x, threshold=None):
eegdash/features/feature_bank/connectivity.py
CHANGED
@@ -1,10 +1,11 @@
 from itertools import chain
+
 import numpy as np
 from scipy.signal import csd

-from ..extractors import FeatureExtractor, BivariateFeature
 from ..decorators import FeaturePredecessor, bivariate_feature
-
+from ..extractors import BivariateFeature, FeatureExtractor
+from . import utils

 __all__ = [
     "CoherenceFeatureExtractor",
@@ -18,82 +19,41 @@ class CoherenceFeatureExtractor(FeatureExtractor):
     def preprocess(self, x, **kwargs):
         f_min = kwargs.pop("f_min") if "f_min" in kwargs else None
         f_max = kwargs.pop("f_max") if "f_max" in kwargs else None
+        assert "fs" in kwargs and "nperseg" in kwargs
         kwargs["axis"] = -1
         n = x.shape[1]
         idx_x, idx_y = BivariateFeature.get_pair_iterators(n)
         ix, iy = list(chain(range(n), idx_x)), list(chain(range(n), idx_y))
         f, s = csd(x[:, ix], x[:, iy], **kwargs)
-
-
-
-
-
-
-        sx, sxy = np.split(s, [n], axis=1)
-        sxx, syy = sx[:, idx_x].real, sx[:, idx_y].real
+        f_min, f_max = utils.get_valid_freq_band(
+            kwargs["fs"], x.shape[-1], f_min, f_max
+        )
+        f, s = utils.slice_freq_band(f, s, f_min=f_min, f_max=f_max)
+        p, sxy = np.split(s, [n], axis=1)
+        sxx, syy = p[:, idx_x].real, p[:, idx_y].real
         c = sxy / np.sqrt(sxx * syy)
         return f, c


-def _avg_over_bands(f, x, bands):
-    bands_avg = dict()
-    for k, v in bands.items():
-        assert isinstance(k, str)
-        assert isinstance(v, tuple)
-        assert len(v) == 2
-        assert v[0] < v[1]
-        mask = np.logical_and(f > v[0], f < v[1])
-        avg = x[..., mask].mean(axis=-1)
-        bands_avg[k] = avg
-    return bands_avg
-
-
 @FeaturePredecessor(CoherenceFeatureExtractor)
 @bivariate_feature
-def connectivity_magnitude_square_coherence(
-    f,
-    c,
-    bands={
-        "delta": (1, 4.5),
-        "theta": (4.5, 8),
-        "alpha": (8, 12),
-        "beta": (12, 30),
-    },
-):
+def connectivity_magnitude_square_coherence(f, c, bands=utils.DEFAULT_FREQ_BANDS):
     # https://neuroimage.usc.edu/brainstorm/Tutorials/Connectivity
     coher = c.real**2 + c.imag**2
-    return
+    return utils.reduce_freq_bands(f, coher, bands, np.mean)


 @FeaturePredecessor(CoherenceFeatureExtractor)
 @bivariate_feature
-def connectivity_imaginary_coherence(
-    f,
-    c,
-    bands={
-        "delta": (1, 4.5),
-        "theta": (4.5, 8),
-        "alpha": (8, 12),
-        "beta": (12, 30),
-    },
-):
+def connectivity_imaginary_coherence(f, c, bands=utils.DEFAULT_FREQ_BANDS):
     # https://neuroimage.usc.edu/brainstorm/Tutorials/Connectivity
     coher = c.imag
-    return
+    return utils.reduce_freq_bands(f, coher, bands, np.mean)


 @FeaturePredecessor(CoherenceFeatureExtractor)
 @bivariate_feature
-def connectivity_lagged_coherence(
-    f,
-    c,
-    bands={
-        "delta": (1, 4.5),
-        "theta": (4.5, 8),
-        "alpha": (8, 12),
-        "beta": (12, 30),
-    },
-):
+def connectivity_lagged_coherence(f, c, bands=utils.DEFAULT_FREQ_BANDS):
     # https://neuroimage.usc.edu/brainstorm/Tutorials/Connectivity
     coher = c.imag / np.sqrt(1 - c.real)
-    return
+    return utils.reduce_freq_bands(f, coher, bands, np.mean)
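utils.py itself does not appear in this diff, but the removed _avg_over_bands shows what the new helpers replace: reduce_freq_bands generalizes it by taking the reduction function as an argument, and DEFAULT_FREQ_BANDS presumably carries the band dict previously inlined in every signature. A sketch of the equivalent logic under those assumptions:

    import numpy as np

    DEFAULT_FREQ_BANDS = {  # the defaults previously inlined above
        "delta": (1, 4.5),
        "theta": (4.5, 8),
        "alpha": (8, 12),
        "beta": (12, 30),
    }

    def reduce_freq_bands(f, x, bands, reduce):
        # Reduce x over the frequency bins falling inside each named band,
        # e.g. with reduce=np.mean, exactly as _avg_over_bands used to do.
        out = {}
        for name, (lo, hi) in bands.items():
            assert lo < hi
            mask = np.logical_and(f > lo, f < hi)
            out[name] = reduce(x[..., mask], axis=-1)
        return out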
eegdash/features/feature_bank/csp.py
CHANGED
@@ -1,11 +1,10 @@
-import numpy as np
 import numba as nb
+import numpy as np
 import scipy
 import scipy.linalg

-from ..extractors import FitableFeature
 from ..decorators import multivariate_feature
-
+from ..extractors import TrainableFeature

 __all__ = [
     "CommonSpatialPattern",
@@ -23,7 +22,7 @@ def _update_mean_cov(count, mean, cov, x_count, x_mean, x_cov):


 @multivariate_feature
-class CommonSpatialPattern(FitableFeature):
+class CommonSpatialPattern(TrainableFeature):
     def __init__(self):
         super().__init__()
