eegdash 0.0.7__py3-none-any.whl → 0.0.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of eegdash might be problematic.

@@ -0,0 +1,453 @@
+ from __future__ import annotations
+ import os
+ import json
+ import shutil
+ import warnings
+ from typing import Dict, no_type_check
+ from collections.abc import Callable, Iterable
+ import numpy as np
+ import pandas as pd
+ from joblib import Parallel, delayed
+ from braindecode.datasets.base import (
+     EEGWindowsDataset,
+     BaseConcatDataset,
+     _create_description,
+ )
+ 
+ 
+ class FeaturesDataset(EEGWindowsDataset):
+     """Dataset of precomputed features served from a pandas DataFrame.
+ 
+     Serves samples from a pandas DataFrame of features along with a target.
+     Targets are read from the ``"target"`` column of ``metadata``, and
+     additional information about the recording / subject is available through
+     the ``description`` attribute.
+ 
+     Parameters
+     ----------
+     features : pandas.DataFrame
+         Tabular feature data, one row per window.
+     metadata : pandas.DataFrame | None
+         One row per window, with at least the columns ``"i_window_in_trial"``,
+         ``"i_start_in_trial"``, ``"i_stop_in_trial"`` and ``"target"``.
+     description : dict | pandas.Series | None
+         Holds additional description about the continuous signal / subject.
+     transform : callable | None
+         On-the-fly transform applied to the example before it is returned.
+     raw_info : dict | None
+         Info of the raw recording the features were computed from, kept so it
+         can be saved alongside the features.
+     raw_preproc_kwargs, window_kwargs, window_preproc_kwargs, features_kwargs : dict | None
+         Keyword arguments used to preprocess the raws, create the windows,
+         preprocess the windows and extract the features; stored for
+         bookkeeping and saving.
+     """
+ 
+     def __init__(
+         self,
+         features: pd.DataFrame,
+         metadata: pd.DataFrame | None = None,
+         description: dict | pd.Series | None = None,
+         transform: Callable | None = None,
+         raw_info: Dict | None = None,
+         raw_preproc_kwargs: Dict | None = None,
+         window_kwargs: Dict | None = None,
+         window_preproc_kwargs: Dict | None = None,
+         features_kwargs: Dict | None = None,
+     ):
+         self.features = features
+         self.n_features = features.columns.size
+         self.metadata = metadata
+         self._description = _create_description(description)
+         self.transform = transform
+         self.raw_info = raw_info
+         self.raw_preproc_kwargs = raw_preproc_kwargs
+         self.window_kwargs = window_kwargs
+         self.window_preproc_kwargs = window_preproc_kwargs
+         self.features_kwargs = features_kwargs
+ 
+         self.crop_inds = metadata.loc[
+             :, ["i_window_in_trial", "i_start_in_trial", "i_stop_in_trial"]
+         ].to_numpy()
+         self.y = metadata.loc[:, "target"].to_list()
+ 
+     def __getitem__(self, index):
+         crop_inds = self.crop_inds[index].tolist()
+         X = self.features.iloc[index].to_numpy()
+         # astype is not in-place; assign the result (this also returns a copy)
+         X = X.astype("float32")
+         if self.transform is not None:
+             X = self.transform(X)
+         y = self.y[index]
+         return X, y, crop_inds
+ 
+     def __len__(self):
+         return len(self.features.index)
+ 
+ 
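For orientation, here is a minimal sketch of constructing and indexing a FeaturesDataset; the feature columns, values and import path are hypothetical, but the metadata columns are the ones the constructor reads:

    import pandas as pd
    from eegdash.features.datasets import FeaturesDataset  # assumed import path

    features = pd.DataFrame({"alpha_power": [1.2, 0.8], "beta_power": [0.4, 0.5]})
    metadata = pd.DataFrame(
        {
            "i_window_in_trial": [0, 1],
            "i_start_in_trial": [0, 200],
            "i_stop_in_trial": [200, 400],
            "target": [0, 1],
        }
    )
    ds = FeaturesDataset(features, metadata=metadata, description={"subject": "01"})
    X, y, crop_inds = ds[0]  # float32 feature vector, target, window indices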
+ def _compute_stats(
+     ds: FeaturesDataset,
+     return_count=False,
+     return_mean=False,
+     return_var=False,
+     ddof=1,
+     numeric_only=False,
+ ):
+     # Per-dataset column-wise count / mean / variance of ds.features.
+     res = []
+     if return_count:
+         res.append(ds.features.count(numeric_only=numeric_only))
+     if return_mean:
+         res.append(ds.features.mean(numeric_only=numeric_only))
+     if return_var:
+         res.append(ds.features.var(ddof=ddof, numeric_only=numeric_only))
+     return tuple(res)
+ 
+ 
+ def _pooled_var(counts, means, variances, ddof):
+     # Pool per-dataset statistics into the count, mean and variance of the
+     # concatenated data: within- and between-dataset sums of squares are
+     # combined, normalised by (total count - ddof) and clipped at 0 for
+     # numerical safety.
+     count = counts.sum(axis=0)
+     mean = np.sum((counts / count) * means, axis=0)
+     var = np.sum(((counts - ddof) / (count - ddof)) * variances, axis=0)
+     var[:] += np.sum((counts / (count - ddof)) * (means**2), axis=0)
+     var[:] -= (count / (count - ddof)) * (mean**2)
+     var[:] = var.clip(min=0)
+     return count, mean, var
+ 
+ 
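As a sanity check on the pooling formula, the pooled mean and variance should match what pandas reports on the concatenated frame; a small synthetic example (not part of the package):

    import numpy as np
    import pandas as pd

    rng = np.random.default_rng(0)
    parts = [pd.DataFrame(rng.normal(size=(50, 3))),
             pd.DataFrame(rng.normal(size=(30, 3)))]
    counts = np.array([p.count() for p in parts])
    means = np.array([p.mean() for p in parts])
    variances = np.array([p.var(ddof=1) for p in parts])

    count, mean, var = _pooled_var(counts, means, variances, ddof=1)
    full = pd.concat(parts, ignore_index=True)
    assert np.allclose(mean, full.mean()) and np.allclose(var, full.var(ddof=1))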
+ class FeaturesConcatDataset(BaseConcatDataset):
+     """A concatenation of FeaturesDataset objects.
+ 
+     Holds FeaturesDataset instances in ``self.datasets`` and exposes a
+     pandas DataFrame with their additional descriptions.
+ 
+     Parameters
+     ----------
+     list_of_ds : list
+         List of FeaturesDataset or FeaturesConcatDataset objects.
+     target_transform : callable | None
+         Optional function to call on targets before returning them.
+     """
+ 
+     def __init__(
+         self,
+         list_of_ds: list[FeaturesDataset] | None = None,
+         target_transform: Callable | None = None,
+     ):
+         # if we get a list of FeaturesConcatDataset, get all the individual datasets
+         if list_of_ds and isinstance(list_of_ds[0], FeaturesConcatDataset):
+             list_of_ds = [d for ds in list_of_ds for d in ds.datasets]
+         super().__init__(list_of_ds)
+ 
+         self.target_transform = target_transform
+ 
+     def split(
+         self,
+         by: str | list[int] | list[list[int]] | dict[str, list[int]],
+     ) -> dict[str, FeaturesConcatDataset]:
+         """Split the dataset based on information listed in its description.
+ 
+         Splitting can be done on a column of the description DataFrame or on
+         explicit dataset indices.
+ 
+         Parameters
+         ----------
+         by : str | list | dict
+             If ``by`` is a string, splitting is performed based on the
+             description DataFrame column with this name.
+             If ``by`` is a (list of) list of integers, the position in the first
+             list corresponds to the split id and the integers to the
+             datapoints of that split.
+             If a dict then each key will be used in the returned
+             splits dict and each value should be a list of int.
+ 
+         Returns
+         -------
+         splits : dict
+             A dictionary with the name of the split (a string) as key and the
+             dataset as value.
+         """
+         if isinstance(by, str):
+             split_ids = {
+                 k: list(v) for k, v in self.description.groupby(by).groups.items()
+             }
+         elif isinstance(by, dict):
+             split_ids = by
+         else:
+             # assume list(int)
+             if not isinstance(by[0], list):
+                 by = [by]
+             # assume list(list(int))
+             split_ids = {split_i: split for split_i, split in enumerate(by)}
+ 
+         return {
+             str(split_name): FeaturesConcatDataset(
+                 [self.datasets[ds_ind] for ds_ind in ds_inds],
+                 target_transform=self.target_transform,
+             )
+             for split_name, ds_inds in split_ids.items()
+         }
+ 
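A short sketch of how split might be used; the "subject" description key and the index groupings are hypothetical:

    splits = concat_ds.split("subject")        # {"01": FeaturesConcatDataset, ...}
    folds = concat_ds.split([[0, 1], [2, 3]])  # {"0": ..., "1": ...}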
+     def get_metadata(self) -> pd.DataFrame:
+         """Concatenate the metadata and description of the wrapped FeaturesDatasets.
+ 
+         Returns
+         -------
+         metadata : pd.DataFrame
+             DataFrame containing as many rows as there are windows in the
+             concatenated datasets, with the metadata and description
+             information for each window.
+         """
+         if not all([isinstance(ds, FeaturesDataset) for ds in self.datasets]):
+             raise TypeError(
+                 "Metadata dataframe can only be computed when all "
+                 "datasets are FeaturesDataset."
+             )
+ 
+         all_dfs = list()
+         for ds in self.datasets:
+             # copy so the description columns are not written back into ds.metadata
+             df = ds.metadata.copy()
+             for k, v in ds.description.items():
+                 df[k] = v
+             all_dfs.append(df)
+ 
+         return pd.concat(all_dfs)
+ 
+     def save(self, path: str, overwrite: bool = False, offset: int = 0):
+         """Save datasets to files by creating one subdirectory for each dataset:
+ 
+         path/
+             0/
+                 0-feat.parquet
+                 metadata_df.pkl
+                 description.json
+                 raw-info.fif (if raw info was saved)
+                 raw_preproc_kwargs.json (if raws were preprocessed)
+                 window_kwargs.json (if this is a windowed dataset)
+                 window_preproc_kwargs.json (if windows were preprocessed)
+                 features_kwargs.json
+             1/
+                 1-feat.parquet
+                 metadata_df.pkl
+                 description.json
+                 raw-info.fif (if raw info was saved)
+                 raw_preproc_kwargs.json (if raws were preprocessed)
+                 window_kwargs.json (if this is a windowed dataset)
+                 window_preproc_kwargs.json (if windows were preprocessed)
+                 features_kwargs.json
+ 
+         Parameters
+         ----------
+         path : str
+             Directory in which subdirectories are created to store the
+             -feat.parquet and .json files.
+         overwrite : bool
+             Whether to delete old subdirectories that will be saved to in this
+             call.
+         offset : int
+             If provided, the integer is added to the id of the dataset in the
+             concat. This is useful in the setting of very large datasets, where
+             one dataset has to be processed and saved at a time to account for
+             its original position.
+         """
+         if len(self.datasets) == 0:
+             raise ValueError("Expect at least one dataset")
+         path_contents = os.listdir(path)
+         # number of existing subdirectories in path
+         n_sub_dirs = sum(
+             os.path.isdir(os.path.join(path, e)) for e in path_contents
+         )
+         for i_ds, ds in enumerate(self.datasets):
+             # remove subdirectory from list of untouched files / subdirectories
+             if str(i_ds + offset) in path_contents:
+                 path_contents.remove(str(i_ds + offset))
+             # save_dir/{i_ds+offset}/
+             sub_dir = os.path.join(path, str(i_ds + offset))
+             if os.path.exists(sub_dir):
+                 if overwrite:
+                     shutil.rmtree(sub_dir)
+                 else:
+                     raise FileExistsError(
+                         f"Subdirectory {sub_dir} already exists. Please select"
+                         f" a different directory, set overwrite=True, or "
+                         f"resolve manually."
+                     )
+             os.makedirs(sub_dir)
+             # save_dir/{i_ds+offset}/{i_ds+offset}-feat.parquet
+             self._save_features(sub_dir, ds, i_ds, offset)
+             # save_dir/{i_ds+offset}/metadata_df.pkl
+             self._save_metadata(sub_dir, ds)
+             # save_dir/{i_ds+offset}/description.json
+             self._save_description(sub_dir, ds.description)
+             # save_dir/{i_ds+offset}/raw-info.fif
+             self._save_raw_info(sub_dir, ds)
+             # save_dir/{i_ds+offset}/raw_preproc_kwargs.json
+             # save_dir/{i_ds+offset}/window_kwargs.json
+             # save_dir/{i_ds+offset}/window_preproc_kwargs.json
+             # save_dir/{i_ds+offset}/features_kwargs.json
+             self._save_kwargs(sub_dir, ds)
+         if overwrite:
+             # the following will be True for all datasets preprocessed and
+             # stored in parallel with braindecode.preprocessing.preprocess
+             if i_ds + 1 + offset < n_sub_dirs:
+                 warnings.warn(
+                     f"The number of saved datasets ({i_ds + 1 + offset}) "
+                     f"does not match the number of existing "
+                     f"subdirectories ({n_sub_dirs}). You may now "
+                     f"encounter a mix of differently preprocessed "
+                     f"datasets!",
+                     UserWarning,
+                 )
+         # if path contains files or directories that were not touched, raise
+         # warning
+         if path_contents:
+             warnings.warn(
+                 f"Chosen directory {path} contains other "
+                 f"subdirectories or files {path_contents}."
+             )
+ 
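A minimal sketch of saving a concatenation, assuming concat_ds is a FeaturesConcatDataset and using a hypothetical output directory:

    import os

    out_dir = "derivatives/features"  # hypothetical path
    os.makedirs(out_dir, exist_ok=True)
    concat_ds.save(out_dir, overwrite=True)
    # -> derivatives/features/0/0-feat.parquet, .../0/description.json, ...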
+     @staticmethod
+     def _save_features(sub_dir, ds, i_ds, offset):
+         parquet_file_name = f"{i_ds + offset}-feat.parquet"
+         parquet_file_path = os.path.join(sub_dir, parquet_file_name)
+         ds.features.to_parquet(parquet_file_path)
+ 
+     @staticmethod
+     def _save_raw_info(sub_dir, ds):
+         # only write the file if a raw info object was actually stored
+         if getattr(ds, "raw_info", None) is not None:
+             fif_file_name = "raw-info.fif"
+             fif_file_path = os.path.join(sub_dir, fif_file_name)
+             ds.raw_info.save(fif_file_path)
+ 
+     @staticmethod
+     def _save_kwargs(sub_dir, ds):
+         for kwargs_name in [
+             "raw_preproc_kwargs",
+             "window_kwargs",
+             "window_preproc_kwargs",
+             "features_kwargs",
+         ]:
+             if hasattr(ds, kwargs_name):
+                 kwargs_file_name = ".".join([kwargs_name, "json"])
+                 kwargs_file_path = os.path.join(sub_dir, kwargs_file_name)
+                 kwargs = getattr(ds, kwargs_name)
+                 if kwargs is not None:
+                     with open(kwargs_file_path, "w") as f:
+                         json.dump(kwargs, f)
+ 
+     def to_dataframe(
+         self, include_metadata=False, include_target=False, include_crop_inds=False
+     ):
+         """Return the features of all datasets as a single DataFrame, optionally
+         joined with the window metadata, the target and/or the crop indices."""
+         if include_metadata or (include_target and include_crop_inds):
+             dataframes = [
+                 ds.metadata.join(ds.features, how="right", lsuffix="_metadata")
+                 for ds in self.datasets
+             ]
+         elif include_target:
+             dataframes = [
+                 ds.features.join(ds.metadata["target"], how="left", rsuffix="_metadata")
+                 for ds in self.datasets
+             ]
+         elif include_crop_inds:
+             dataframes = [
+                 ds.metadata.drop("target", axis="columns").join(
+                     ds.features, how="right", lsuffix="_metadata"
+                 )
+                 for ds in self.datasets
+             ]
+         else:
+             dataframes = [ds.features for ds in self.datasets]
+         return pd.concat(dataframes, axis=0, ignore_index=True)
+ 
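For example, a feature matrix with targets could be pulled out like this (sketch; the feature column names depend on the extractors used):

    df = concat_ds.to_dataframe(include_target=True)
    X = df.drop(columns="target").to_numpy(dtype="float32")
    y = df["target"].to_numpy()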
+     def _numeric_columns(self):
+         return self.datasets[0].features.select_dtypes(include=np.number).columns
+ 
+     def count(self, numeric_only=False, n_jobs=1):
+         stats = Parallel(n_jobs)(
+             delayed(_compute_stats)(ds, return_count=True, numeric_only=numeric_only)
+             for ds in self.datasets
+         )
+         counts = np.array([s[0] for s in stats])
+         count = counts.sum(axis=0)
+         return pd.Series(count, index=self._numeric_columns())
+ 
+     def mean(self, numeric_only=False, n_jobs=1):
+         stats = Parallel(n_jobs)(
+             delayed(_compute_stats)(
+                 ds, return_count=True, return_mean=True, numeric_only=numeric_only
+             )
+             for ds in self.datasets
+         )
+         counts, means = np.array([s[0] for s in stats]), np.array([s[1] for s in stats])
+         count = counts.sum(axis=0, keepdims=True)
+         mean = np.sum((counts / count) * means, axis=0)
+         return pd.Series(mean, index=self._numeric_columns())
+ 
+     def var(self, ddof=1, numeric_only=False, n_jobs=1):
+         stats = Parallel(n_jobs)(
+             delayed(_compute_stats)(
+                 ds,
+                 return_count=True,
+                 return_mean=True,
+                 return_var=True,
+                 ddof=ddof,
+                 numeric_only=numeric_only,
+             )
+             for ds in self.datasets
+         )
+         counts, means, variances = (
+             np.array([s[0] for s in stats]),
+             np.array([s[1] for s in stats]),
+             np.array([s[2] for s in stats]),
+         )
+         _, _, var = _pooled_var(counts, means, variances, ddof)
+         return pd.Series(var, index=self._numeric_columns())
+ 
+     def std(self, ddof=1, numeric_only=False, n_jobs=1):
+         return np.sqrt(self.var(ddof=ddof, numeric_only=numeric_only, n_jobs=n_jobs))
+ 
+     def zscore(self, ddof=1, numeric_only=False, eps=0, n_jobs=1):
+         stats = Parallel(n_jobs)(
+             delayed(_compute_stats)(
+                 ds,
+                 return_count=True,
+                 return_mean=True,
+                 return_var=True,
+                 ddof=ddof,
+                 numeric_only=numeric_only,
+             )
+             for ds in self.datasets
+         )
+         counts, means, variances = (
+             np.array([s[0] for s in stats]),
+             np.array([s[1] for s in stats]),
+             np.array([s[2] for s in stats]),
+         )
+         _, mean, var = _pooled_var(counts, means, variances, ddof)
+         std = np.sqrt(var) + eps
+         # normalise every dataset's features in place with the pooled statistics
+         for ds in self.datasets:
+             ds.features = (ds.features - mean) / std
+ 
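A sketch of the dataset-wide statistics defined above; eps is simply the small constant added to the pooled standard deviation before dividing:

    feature_means = concat_ds.mean()   # pooled over all windows of all datasets
    feature_stds = concat_ds.std(ddof=1)
    concat_ds.zscore(eps=1e-8)         # normalises every ds.features in place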
+     @staticmethod
+     def _enforce_inplace_operations(func_name, kwargs):
+         if "inplace" in kwargs and kwargs["inplace"] is False:
+             raise ValueError(
+                 f"{func_name} only works in-place; please set inplace=True "
+                 "(the default)."
+             )
+         kwargs["inplace"] = True
+ 
+     def fillna(self, *args, **kwargs):
+         FeaturesConcatDataset._enforce_inplace_operations("fillna", kwargs)
+         for ds in self.datasets:
+             ds.features.fillna(*args, **kwargs)
+ 
+     def replace(self, *args, **kwargs):
+         FeaturesConcatDataset._enforce_inplace_operations("replace", kwargs)
+         for ds in self.datasets:
+             ds.features.replace(*args, **kwargs)
+ 
+     def interpolate(self, *args, **kwargs):
+         FeaturesConcatDataset._enforce_inplace_operations("interpolate", kwargs)
+         for ds in self.datasets:
+             ds.features.interpolate(*args, **kwargs)
+ 
+     def dropna(self, *args, **kwargs):
+         FeaturesConcatDataset._enforce_inplace_operations("dropna", kwargs)
+         for ds in self.datasets:
+             ds.features.dropna(*args, **kwargs)
+ 
+     def drop(self, *args, **kwargs):
+         FeaturesConcatDataset._enforce_inplace_operations("drop", kwargs)
+         for ds in self.datasets:
+             ds.features.drop(*args, **kwargs)
+ 
+     def join(self, concat_dataset: FeaturesConcatDataset, **kwargs):
+         assert len(self.datasets) == len(concat_dataset.datasets)
+         for ds1, ds2 in zip(self.datasets, concat_dataset.datasets):
+             assert len(ds1) == len(ds2)
+             # DataFrame.join is not in-place; join the feature tables and keep
+             # the result
+             ds1.features = ds1.features.join(ds2.features, **kwargs)
@@ -0,0 +1,43 @@
+ from typing import Type
+ from collections.abc import Callable
+ 
+ from .extractors import (
+     FeatureExtractor,
+     UnivariateFeature,
+     BivariateFeature,
+     DirectedBivariateFeature,
+     MultivariateFeature,
+ )
+ from .extractors import _get_underlying_func
+ 
+ 
+ class FeaturePredecessor:
+     """Decorator recording which FeatureExtractor types a feature may be nested under."""
+ 
+     def __init__(self, *parent_extractor_type: Type[FeatureExtractor]):
+         parent_cls = parent_extractor_type
+         if not parent_cls:
+             parent_cls = [FeatureExtractor]
+         for p_cls in parent_cls:
+             assert issubclass(p_cls, FeatureExtractor)
+         self.parent_extractor_type = parent_cls
+ 
+     def __call__(self, func: Callable):
+         f = _get_underlying_func(func)
+         f.parent_extractor_type = self.parent_extractor_type
+         return func
+ 
+ 
+ class FeatureKind:
+     """Decorator attaching a feature kind (univariate, bivariate, ...) to a feature function."""
+ 
+     def __init__(self, feature_kind: MultivariateFeature):
+         self.feature_kind = feature_kind
+ 
+     def __call__(self, func):
+         f = _get_underlying_func(func)
+         f.feature_kind = self.feature_kind
+         return func
+ 
+ 
+ # Syntactic sugar for the common feature kinds
+ univariate_feature = FeatureKind(UnivariateFeature())
+ bivariate_feature = FeatureKind(BivariateFeature())
+ directed_bivariate_feature = FeatureKind(DirectedBivariateFeature())
+ multivariate_feature = FeatureKind(MultivariateFeature())
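A sketch of how these decorators are meant to be used; the peak-to-peak feature below and the (batch, channels, times) input layout are assumptions, not part of this file:

    import numpy as np

    @univariate_feature
    def peak_to_peak(x):
        # assumed input: windows of shape (batch, channels, times);
        # one value per channel -> shape (batch, channels)
        return x.max(axis=-1) - x.min(axis=-1)

    # FeatureExtractor reads the attributes set by the decorators
    extractor = FeatureExtractor({"ptp": peak_to_peak})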
@@ -0,0 +1,209 @@
+ from abc import ABC, abstractmethod
+ from typing import Dict
+ from collections.abc import Callable
+ from functools import partial
+ import numpy as np
+ from numba.core.dispatcher import Dispatcher
+ 
+ 
+ def _get_underlying_func(func):
+     # Unwrap functools.partial objects and numba dispatchers so that the
+     # attributes set by the feature decorators live on the underlying function.
+     f = func
+     if isinstance(f, partial):
+         f = f.func
+     if isinstance(f, Dispatcher):
+         f = f.py_func
+     return f
+ 
+ 
+ class FitableFeature(ABC):
+     """Base class for features that accumulate state and must be fitted before use."""
+ 
+     def __init__(self):
+         self._is_fitted = False
+         self.clear()
+ 
+     @abstractmethod
+     def clear(self):
+         pass
+ 
+     @abstractmethod
+     def partial_fit(self, *x, y=None):
+         pass
+ 
+     def fit(self):
+         self._is_fitted = True
+ 
+     def __call__(self, *args, **kwargs):
+         if not self._is_fitted:
+             raise RuntimeError(
+                 f"{self.__class__} cannot be called, it has to be fitted first."
+             )
+ 
+ 
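A minimal sketch of a concrete FitableFeature; the running-maximum feature is hypothetical and only illustrates the clear / partial_fit / fit / __call__ protocol:

    class RunningMax(FitableFeature):
        def clear(self):
            self._max = None

        def partial_fit(self, *x, y=None):
            m = float(np.max(x[0]))
            self._max = m if self._max is None else max(self._max, m)

        def __call__(self, *x):
            super().__call__()  # raises unless fit() has been called
            return np.minimum(x[0], self._max)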
+ class FeatureExtractor(FitableFeature):
+     def __init__(
+         self, feature_extractors: Dict[str, Callable], **preprocess_kwargs: Dict
+     ):
+         self.feature_extractors_dict = self._validate_execution_tree(feature_extractors)
+         self._is_fitable = self._check_is_fitable(feature_extractors)
+         super().__init__()
+ 
+         # bypassing FeaturePredecessor to avoid circular import
+         if not hasattr(self, "parent_extractor_type"):
+             self.parent_extractor_type = [FeatureExtractor]
+ 
+         self.preprocess_kwargs = preprocess_kwargs
+         if self.preprocess_kwargs is None:
+             self.preprocess_kwargs = dict()
+         self.features_kwargs = {
+             "preprocess_kwargs": preprocess_kwargs,
+         }
+         for fn, fe in feature_extractors.items():
+             if isinstance(fe, FeatureExtractor):
+                 self.features_kwargs[fn] = fe.features_kwargs
+             if isinstance(fe, partial):
+                 self.features_kwargs[fn] = fe.keywords
+ 
+     def _validate_execution_tree(self, feature_extractors):
+         for fname, f in feature_extractors.items():
+             f = _get_underlying_func(f)
+             pe_type = getattr(f, "parent_extractor_type", [FeatureExtractor])
+             assert type(self) in pe_type
+         return feature_extractors
+ 
+     def _check_is_fitable(self, feature_extractors):
+         is_fitable = False
+         for fname, f in feature_extractors.items():
+             if isinstance(f, FeatureExtractor):
+                 is_fitable = f._is_fitable
+             else:
+                 f = _get_underlying_func(f)
+                 if isinstance(f, FitableFeature):
+                     is_fitable = True
+             if is_fitable:
+                 break
+         return is_fitable
+ 
86
+ return (*x,)
87
+
88
+ def feature_channel_names(self, ch_names):
89
+ return [""]
90
+
91
+ def __call__(self, *x, _batch_size=None, _ch_names=None):
92
+ assert _batch_size is not None
93
+ assert _ch_names is not None
94
+ if self._is_fitable:
95
+ super().__call__()
96
+ results_dict = dict()
97
+ z = self.preprocess(*x, **self.preprocess_kwargs)
98
+ for fname, f in self.feature_extractors_dict.items():
99
+ if isinstance(f, FeatureExtractor):
100
+ r = f(*z, _batch_size=_batch_size, _ch_names=_ch_names)
101
+ else:
102
+ r = f(*z)
103
+ f = _get_underlying_func(f)
104
+ if hasattr(f, "feature_kind"):
105
+ r = f.feature_kind(r, _ch_names=_ch_names)
106
+ if not isinstance(fname, str) or not fname:
107
+ if isinstance(f, FeatureExtractor) or not hasattr(f, "__name__"):
108
+ fname = ""
109
+ else:
110
+ fname = f.__name__
111
+ if isinstance(r, dict):
112
+ if fname:
113
+ fname += "_"
114
+ for k, v in r.items():
115
+ self._add_feature_to_dict(results_dict, fname + k, v, _batch_size)
116
+ else:
117
+ self._add_feature_to_dict(results_dict, fname, r, _batch_size)
118
+ return results_dict
119
+
+     def _add_feature_to_dict(self, results_dict, name, value, batch_size):
+         if not isinstance(value, np.ndarray):
+             results_dict[name] = value
+         else:
+             assert value.shape[0] == batch_size
+             results_dict[name] = value
+ 
+     def clear(self):
+         if not self._is_fitable:
+             return
+         for fname, f in self.feature_extractors_dict.items():
+             f = _get_underlying_func(f)
+             if isinstance(f, FitableFeature):
+                 f.clear()
+ 
+     def partial_fit(self, *x, y=None):
+         if not self._is_fitable:
+             return
+         z = self.preprocess(*x, **self.preprocess_kwargs)
+         for fname, f in self.feature_extractors_dict.items():
+             f = _get_underlying_func(f)
+             if isinstance(f, FitableFeature):
+                 f.partial_fit(*z, y=y)
+ 
+     def fit(self):
+         if not self._is_fitable:
+             return
+         for fname, f in self.feature_extractors_dict.items():
+             f = _get_underlying_func(f)
+             if isinstance(f, FitableFeature):
+                 f.fit()
+         super().fit()
+ 
+ 
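To make the naming convention concrete, a sketch of calling an extractor on a batch of windows; the input shape and channel names are assumptions:

    import numpy as np

    def ptp(x):
        return x.max(axis=-1) - x.min(axis=-1)  # (batch, channels)

    ptp.feature_kind = UnivariateFeature()  # what @univariate_feature would set

    extractor = FeatureExtractor({"ptp": ptp})
    x = np.random.randn(4, 2, 100)  # 4 windows, 2 channels
    out = extractor(x, _batch_size=4, _ch_names=["C3", "C4"])
    # out == {"ptp_C3": <4 values>, "ptp_C4": <4 values>}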
+ class MultivariateFeature:
+     def __call__(self, x, _ch_names=None):
+         assert _ch_names is not None
+         f_channels = self.feature_channel_names(_ch_names)
+         if isinstance(x, dict):
+             r = dict()
+             for k, v in x.items():
+                 r.update(self._array_to_dict(v, f_channels, k))
+             return r
+         return self._array_to_dict(x, f_channels)
+ 
+     @staticmethod
+     def _array_to_dict(x, f_channels, name=""):
+         assert isinstance(x, np.ndarray)
+         if len(f_channels) == 0:
+             assert x.ndim == 1
+             if name:
+                 return {name: x}
+             return x
+         assert x.shape[1] == len(f_channels)
+         x = x.swapaxes(0, 1)
+         names = [f"{name}_{ch}" for ch in f_channels] if name else f_channels
+         return dict(zip(names, x))
+ 
+     def feature_channel_names(self, ch_names):
+         return []
+ 
+ 
+ class UnivariateFeature(MultivariateFeature):
+     def feature_channel_names(self, ch_names):
+         return ch_names
+ 
+ 
+ class BivariateFeature(MultivariateFeature):
+     def __init__(self, *args, channel_pair_format="{}<>{}"):
+         super().__init__(*args)
+         self.channel_pair_format = channel_pair_format
+ 
+     @staticmethod
+     def get_pair_iterators(n):
+         return np.triu_indices(n, 1)
+ 
+     def feature_channel_names(self, ch_names):
+         return [
+             self.channel_pair_format.format(ch_names[i], ch_names[j])
+             for i, j in zip(*self.get_pair_iterators(len(ch_names)))
+         ]
+ 
+ 
+ class DirectedBivariateFeature(BivariateFeature):
+     @staticmethod
+     def get_pair_iterators(n):
+         return [
+             np.append(a, b)
+             for a, b in zip(np.tril_indices(n, -1), np.triu_indices(n, 1))
+         ]
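For reference, the channel-pair naming that follows directly from np.triu_indices / np.tril_indices and the default "{}<>{}" format:

    BivariateFeature().feature_channel_names(["C3", "Cz", "C4"])
    # -> ['C3<>Cz', 'C3<>C4', 'Cz<>C4']
    DirectedBivariateFeature().feature_channel_names(["C3", "Cz"])
    # -> ['Cz<>C3', 'C3<>Cz']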