eegdash 0.0.8__py3-none-any.whl → 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of eegdash might be problematic.

@@ -0,0 +1,456 @@
+ from __future__ import annotations
+
+ import json
+ import os
+ import shutil
+ import warnings
+ from collections.abc import Callable, Iterable
+ from typing import Dict, no_type_check
+
+ import numpy as np
+ import pandas as pd
+ from joblib import Parallel, delayed
+
+ from braindecode.datasets.base import (
+     BaseConcatDataset,
+     EEGWindowsDataset,
+     _create_description,
+ )
+
+
+ class FeaturesDataset(EEGWindowsDataset):
+     """Returns samples from a pandas DataFrame object along with a target.
+
+     Dataset which serves samples from a pandas DataFrame object along with a
+     target. The target is unique for the dataset, and is obtained through the
+     `description` attribute.
+
+     Parameters
+     ----------
+     features : a pandas DataFrame
+         Tabular data.
+     description : dict | pandas.Series | None
+         Holds additional description about the continuous signal / subject.
+     transform : callable | None
+         On-the-fly transform applied to the example before it is returned.
+     """
+
+     def __init__(
+         self,
+         features: pd.DataFrame,
+         metadata: pd.DataFrame | None = None,
+         description: dict | pd.Series | None = None,
+         transform: Callable | None = None,
+         raw_info: Dict | None = None,
+         raw_preproc_kwargs: Dict | None = None,
+         window_kwargs: Dict | None = None,
+         window_preproc_kwargs: Dict | None = None,
+         features_kwargs: Dict | None = None,
+     ):
+         self.features = features
+         self.n_features = features.columns.size
+         self.metadata = metadata
+         self._description = _create_description(description)
+         self.transform = transform
+         self.raw_info = raw_info
+         self.raw_preproc_kwargs = raw_preproc_kwargs
+         self.window_kwargs = window_kwargs
+         self.window_preproc_kwargs = window_preproc_kwargs
+         self.features_kwargs = features_kwargs
+
+         self.crop_inds = metadata.loc[
+             :, ["i_window_in_trial", "i_start_in_trial", "i_stop_in_trial"]
+         ].to_numpy()
+         self.y = metadata.loc[:, "target"].to_list()
+
+     def __getitem__(self, index):
+         crop_inds = self.crop_inds[index].tolist()
+         X = self.features.iloc[index].to_numpy()
+         X = X.copy()
+         X = X.astype("float32")
+         if self.transform is not None:
+             X = self.transform(X)
+         y = self.y[index]
+         return X, y, crop_inds
+
+     def __len__(self):
+         return len(self.features.index)
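
A minimal usage sketch of the class above (the feature names, values, and description below are illustrative, not taken from the package): given a features DataFrame with one row per window and a metadata DataFrame carrying the bookkeeping columns and target used in __init__, indexing returns the float32 feature vector, the target, and the crop indices.

    import pandas as pd

    features = pd.DataFrame({"alpha_power": [1.0, 2.0, 3.0], "beta_power": [0.5, 0.4, 0.3]})
    metadata = pd.DataFrame({
        "i_window_in_trial": [0, 1, 2],
        "i_start_in_trial": [0, 100, 200],
        "i_stop_in_trial": [100, 200, 300],
        "target": [0, 1, 0],
    })
    ds = FeaturesDataset(features, metadata=metadata, description={"subject": 1})
    X, y, crop_inds = ds[0]  # feature vector, target, [i_window, i_start, i_stop]
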
+
+
+ def _compute_stats(
+     ds: FeaturesDataset,
+     return_count=False,
+     return_mean=False,
+     return_var=False,
+     ddof=1,
+     numeric_only=False,
+ ):
+     res = []
+     if return_count:
+         res.append(ds.features.count(numeric_only=numeric_only))
+     if return_mean:
+         res.append(ds.features.mean(numeric_only=numeric_only))
+     if return_var:
+         res.append(ds.features.var(ddof=ddof, numeric_only=numeric_only))
+     return tuple(res)
+
+
+ def _pooled_var(counts, means, variances, ddof):
+     count = counts.sum(axis=0)
+     mean = np.sum((counts / count) * means, axis=0)
+     var = np.sum(((counts - ddof) / (count - ddof)) * variances, axis=0)
+     var[:] += np.sum((counts / (count - ddof)) * (means**2), axis=0)
+     var[:] -= (count / (count - ddof)) * (mean**2)
+     var[:] = var.clip(min=0)
+     return count, mean, var
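
_pooled_var combines per-dataset counts, means, and ddof-corrected variances into the statistics of the concatenated data without materializing it. A quick sanity check of the algebra (a sketch with random data, not part of the package):

    import numpy as np

    rng = np.random.default_rng(0)
    parts = [rng.normal(size=(n, 4)) for n in (50, 80, 30)]   # three feature blocks
    counts = np.array([np.full(4, p.shape[0]) for p in parts])
    means = np.array([p.mean(axis=0) for p in parts])
    variances = np.array([p.var(axis=0, ddof=1) for p in parts])

    count, mean, var = _pooled_var(counts, means, variances, ddof=1)
    full = np.concatenate(parts, axis=0)
    assert np.allclose(mean, full.mean(axis=0))
    assert np.allclose(var, full.var(axis=0, ddof=1))
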
+
+
+ class FeaturesConcatDataset(BaseConcatDataset):
+     """A base class for concatenated feature datasets.
+
+     Holds FeaturesDataset objects in self.datasets and has
+     a pandas DataFrame with additional description.
+
+     Parameters
+     ----------
+     list_of_ds : list
+         list of FeaturesDataset or FeaturesConcatDataset
+     target_transform : callable | None
+         Optional function to call on targets before returning them.
+
+     """
+
+     def __init__(
+         self,
+         list_of_ds: list[FeaturesDataset] | None = None,
+         target_transform: Callable | None = None,
+     ):
+         # if we get a list of FeaturesConcatDataset, get all the individual datasets
+         if list_of_ds and isinstance(list_of_ds[0], FeaturesConcatDataset):
+             list_of_ds = [d for ds in list_of_ds for d in ds.datasets]
+         super().__init__(list_of_ds)
+
+         self.target_transform = target_transform
+
+     def split(
+         self,
+         by: str | list[int] | list[list[int]] | dict[str, list[int]],
+     ) -> dict[str, FeaturesConcatDataset]:
+         """Split the dataset based on information listed in its description.
+
+         The format could be based on a DataFrame or based on indices.
+
+         Parameters
+         ----------
+         by : str | list | dict
+             If ``by`` is a string, splitting is performed based on the
+             description DataFrame column with this name.
+             If ``by`` is a (list of) list of integers, the position in the first
+             list corresponds to the split id and the integers to the
+             datapoints of that split.
+             If a dict then each key will be used in the returned
+             splits dict and each value should be a list of int.
+
+         Returns
+         -------
+         splits : dict
+             A dictionary with the name of the split (a string) as key and the
+             dataset as value.
+         """
+         if isinstance(by, str):
+             split_ids = {
+                 k: list(v) for k, v in self.description.groupby(by).groups.items()
+             }
+         elif isinstance(by, dict):
+             split_ids = by
+         else:
+             # assume list(int)
+             if not isinstance(by[0], list):
+                 by = [by]
+             # assume list(list(int))
+             split_ids = {split_i: split for split_i, split in enumerate(by)}
+
+         return {
+             str(split_name): FeaturesConcatDataset(
+                 [self.datasets[ds_ind] for ds_ind in ds_inds],
+                 target_transform=self.target_transform,
+             )
+             for split_name, ds_inds in split_ids.items()
+         }
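
The three accepted forms of by, sketched on a hypothetical concat dataset whose description has a "subject" column (names are illustrative):

    concat_ds.split("subject")                        # one split per subject value
    concat_ds.split([[0, 1], [2]])                    # splits "0" and "1" by dataset index
    concat_ds.split({"train": [0, 1], "valid": [2]})  # named splits

Each value in the returned dict is itself a FeaturesConcatDataset.
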
+
+     def get_metadata(self) -> pd.DataFrame:
+         """Concatenate the metadata and description of the wrapped datasets.
+
+         Returns
+         -------
+         metadata : pd.DataFrame
+             DataFrame containing as many rows as there are windows in the
+             BaseConcatDataset, with the metadata and description information
+             for each window.
+         """
+         if not all([isinstance(ds, FeaturesDataset) for ds in self.datasets]):
+             raise TypeError(
+                 "Metadata dataframe can only be computed when all "
+                 "datasets are FeaturesDataset."
+             )
+
+         all_dfs = list()
+         for ds in self.datasets:
+             df = ds.metadata
+             for k, v in ds.description.items():
+                 df[k] = v
+             all_dfs.append(df)
+
+         return pd.concat(all_dfs)
+
+     def save(self, path: str, overwrite: bool = False, offset: int = 0):
+         """Save datasets to files by creating one subdirectory for each dataset:
+         path/
+             0/
+                 0-feat.parquet
+                 metadata_df.pkl
+                 description.json
+                 raw-info.fif (if raw info was saved)
+                 raw_preproc_kwargs.json (if raws were preprocessed)
+                 window_kwargs.json (if this is a windowed dataset)
+                 window_preproc_kwargs.json (if windows were preprocessed)
+                 features_kwargs.json
+             1/
+                 1-feat.parquet
+                 metadata_df.pkl
+                 description.json
+                 raw-info.fif (if raw info was saved)
+                 raw_preproc_kwargs.json (if raws were preprocessed)
+                 window_kwargs.json (if this is a windowed dataset)
+                 window_preproc_kwargs.json (if windows were preprocessed)
+                 features_kwargs.json
+
+         Parameters
+         ----------
+         path : str
+             Directory in which subdirectories are created to store
+             -feat.parquet and .json files to.
+         overwrite : bool
+             Whether to delete old subdirectories that will be saved to in this
+             call.
+         offset : int
+             If provided, the integer is added to the id of the dataset in the
+             concat. This is useful in the setting of very large datasets, where
+             one dataset has to be processed and saved at a time to account for
+             its original position.
+         """
+         if len(self.datasets) == 0:
+             raise ValueError("Expect at least one dataset")
+         path_contents = os.listdir(path)
+         n_sub_dirs = len(
+             [e for e in path_contents if os.path.isdir(os.path.join(path, e))]
+         )
+         for i_ds, ds in enumerate(self.datasets):
+             # remove subdirectory from list of untouched files / subdirectories
+             if str(i_ds + offset) in path_contents:
+                 path_contents.remove(str(i_ds + offset))
+             # save_dir/i_ds/
+             sub_dir = os.path.join(path, str(i_ds + offset))
+             if os.path.exists(sub_dir):
+                 if overwrite:
+                     shutil.rmtree(sub_dir)
+                 else:
+                     raise FileExistsError(
+                         f"Subdirectory {sub_dir} already exists. Please select"
+                         f" a different directory, set overwrite=True, or "
+                         f"resolve manually."
+                     )
+             # save_dir/{i_ds+offset}/
+             os.makedirs(sub_dir)
+             # save_dir/{i_ds+offset}/{i_ds+offset}-feat.parquet
+             self._save_features(sub_dir, ds, i_ds, offset)
+             # save_dir/{i_ds+offset}/metadata_df.pkl
+             self._save_metadata(sub_dir, ds)
+             # save_dir/{i_ds+offset}/description.json
+             self._save_description(sub_dir, ds.description)
+             # save_dir/{i_ds+offset}/raw-info.fif
+             self._save_raw_info(sub_dir, ds)
+             # save_dir/{i_ds+offset}/raw_preproc_kwargs.json
+             # save_dir/{i_ds+offset}/window_kwargs.json
+             # save_dir/{i_ds+offset}/window_preproc_kwargs.json
+             # save_dir/{i_ds+offset}/features_kwargs.json
+             self._save_kwargs(sub_dir, ds)
+         if overwrite:
+             # the following will be True for all datasets preprocessed and
+             # stored in parallel with braindecode.preprocessing.preprocess
+             if i_ds + 1 + offset < n_sub_dirs:
+                 warnings.warn(
+                     f"The number of saved datasets ({i_ds + 1 + offset}) "
+                     f"does not match the number of existing "
+                     f"subdirectories ({n_sub_dirs}). You may now "
+                     f"encounter a mix of differently preprocessed "
+                     f"datasets!",
+                     UserWarning,
+                 )
+         # if path contains files or directories that were not touched, raise
+         # warning
+         if path_contents:
+             warnings.warn(
+                 f"Chosen directory {path} contains other "
+                 f"subdirectories or files {path_contents}."
+             )
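
A sketch of how save might be called (directory and variable names are illustrative). For very large collections, the offset parameter lets each dataset be featurized and saved one at a time while keeping its original position:

    import os

    os.makedirs("features_dir", exist_ok=True)
    concat_ds.save("features_dir", overwrite=True)

    # one-dataset-at-a-time variant
    for i, single_ds in enumerate(feature_datasets):
        FeaturesConcatDataset([single_ds]).save("features_dir", overwrite=True, offset=i)
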
295
+
296
+ @staticmethod
297
+ def _save_features(sub_dir, ds, i_ds, offset):
298
+ parquet_file_name = f"{i_ds + offset}-feat.parquet"
299
+ parquet_file_path = os.path.join(sub_dir, parquet_file_name)
300
+ ds.features.to_parquet(parquet_file_path)
301
+
302
+ @staticmethod
303
+ def _save_raw_info(sub_dir, ds):
304
+ if hasattr(ds, "raw_info"):
305
+ fif_file_name = "raw-info.fif"
306
+ fif_file_path = os.path.join(sub_dir, fif_file_name)
307
+ ds.raw_info.save(fif_file_path)
308
+
309
+ @staticmethod
310
+ def _save_kwargs(sub_dir, ds):
311
+ for kwargs_name in [
312
+ "raw_preproc_kwargs",
313
+ "window_kwargs",
314
+ "window_preproc_kwargs",
315
+ "features_kwargs",
316
+ ]:
317
+ if hasattr(ds, kwargs_name):
318
+ kwargs_file_name = ".".join([kwargs_name, "json"])
319
+ kwargs_file_path = os.path.join(sub_dir, kwargs_file_name)
320
+ kwargs = getattr(ds, kwargs_name)
321
+ if kwargs is not None:
322
+ with open(kwargs_file_path, "w") as f:
323
+ json.dump(kwargs, f)
324
+
325
+ def to_dataframe(
326
+ self, include_metadata=False, include_target=False, include_crop_inds=False
327
+ ):
328
+ if include_metadata or (include_target and include_crop_inds):
329
+ dataframes = [
330
+ ds.metadata.join(ds.features, how="right", lsuffix="_metadata")
331
+ for ds in self.datasets
332
+ ]
333
+ elif include_target:
334
+ dataframes = [
335
+ ds.features.join(ds.metadata["target"], how="left", rsuffix="_metadata")
336
+ for ds in self.datasets
337
+ ]
338
+ elif include_crop_inds:
339
+ dataframes = [
340
+ ds.metadata.drop("target", axis="columns").join(
341
+ ds.features, how="right", lsuffix="_metadata"
342
+ )
343
+ for ds in self.datasets
344
+ ]
345
+ else:
346
+ dataframes = [ds.features for ds in self.datasets]
347
+ return pd.concat(dataframes, axis=0, ignore_index=True)
348
+
349
+ def _numeric_columns(self):
350
+ return self.datasets[0].features.select_dtypes(include=np.number).columns
351
+
352
+ def count(self, numeric_only=False, n_jobs=1):
353
+ stats = Parallel(n_jobs)(
354
+ delayed(_compute_stats)(ds, return_count=True, numeric_only=numeric_only)
355
+ for ds in self.datasets
356
+ )
357
+ counts = np.array([s[0] for s in stats])
358
+ count = counts.sum(axis=0)
359
+ return pd.Series(count, index=self._numeric_columns())
360
+
361
+ def mean(self, numeric_only=False, n_jobs=1):
362
+ stats = Parallel(n_jobs)(
363
+ delayed(_compute_stats)(
364
+ ds, return_count=True, return_mean=True, numeric_only=numeric_only
365
+ )
366
+ for ds in self.datasets
367
+ )
368
+ counts, means = np.array([s[0] for s in stats]), np.array([s[1] for s in stats])
369
+ count = counts.sum(axis=0, keepdims=True)
370
+ mean = np.sum((counts / count) * means, axis=0)
371
+ return pd.Series(mean, index=self._numeric_columns())
372
+
373
+ def var(self, ddof=1, numeric_only=False, n_jobs=1):
374
+ stats = Parallel(n_jobs)(
375
+ delayed(_compute_stats)(
376
+ ds,
377
+ return_count=True,
378
+ return_mean=True,
379
+ return_var=True,
380
+ ddof=ddof,
381
+ numeric_only=numeric_only,
382
+ )
383
+ for ds in self.datasets
384
+ )
385
+ counts, means, variances = (
386
+ np.array([s[0] for s in stats]),
387
+ np.array([s[1] for s in stats]),
388
+ np.array([s[2] for s in stats]),
389
+ )
390
+ _, _, var = _pooled_var(counts, means, variances, ddof)
391
+ return pd.Series(var, index=self._numeric_columns())
392
+
393
+ def std(self, ddof=1, numeric_only=False, n_jobs=1):
394
+ return np.sqrt(self.var(ddof=ddof, numeric_only=numeric_only, n_jobs=n_jobs))
395
+
396
+ def zscore(self, ddof=1, numeric_only=False, eps=0, n_jobs=1):
397
+ stats = Parallel(n_jobs)(
398
+ delayed(_compute_stats)(
399
+ ds,
400
+ return_count=True,
401
+ return_mean=True,
402
+ return_var=True,
403
+ ddof=ddof,
404
+ numeric_only=numeric_only,
405
+ )
406
+ for ds in self.datasets
407
+ )
408
+ counts, means, variances = (
409
+ np.array([s[0] for s in stats]),
410
+ np.array([s[1] for s in stats]),
411
+ np.array([s[2] for s in stats]),
412
+ )
413
+ _, mean, var = _pooled_var(counts, means, variances, ddof)
414
+ std = np.sqrt(var) + eps
415
+ for ds in self.datasets:
416
+ ds.features = (ds.features - mean) / std
417
+
418
+ @staticmethod
419
+ def _enforce_inplace_operations(func_name, kwargs):
420
+ if "inplace" in kwargs and kwargs["inplace"] is False:
421
+ raise ValueError(
422
+ f"{func_name} only works inplace, please change "
423
+ + "to inplace=True (default)."
424
+ )
425
+ kwargs["inplace"] = True
426
+
427
+ def fillna(self, *args, **kwargs):
428
+ FeaturesConcatDataset._enforce_inplace_operations("fillna", kwargs)
429
+ for ds in self.datasets:
430
+ ds.features.fillna(*args, **kwargs)
431
+
432
+ def replace(self, *args, **kwargs):
433
+ FeaturesConcatDataset._enforce_inplace_operations("replace", kwargs)
434
+ for ds in self.datasets:
435
+ ds.features.replace(*args, **kwargs)
436
+
437
+ def interpolate(self, *args, **kwargs):
438
+ FeaturesConcatDataset._enforce_inplace_operations("interpolate", kwargs)
439
+ for ds in self.datasets:
440
+ ds.features.interpolate(*args, **kwargs)
441
+
442
+ def dropna(self, *args, **kwargs):
443
+ FeaturesConcatDataset._enforce_inplace_operations("dropna", kwargs)
444
+ for ds in self.datasets:
445
+ ds.features.dropna(*args, **kwargs)
446
+
447
+ def drop(self, *args, **kwargs):
448
+ FeaturesConcatDataset._enforce_inplace_operations("drop", kwargs)
449
+ for ds in self.datasets:
450
+ ds.features.drop(*args, **kwargs)
451
+
+     def join(self, concat_dataset: FeaturesConcatDataset, **kwargs):
+         assert len(self.datasets) == len(concat_dataset.datasets)
+         for ds1, ds2 in zip(self.datasets, concat_dataset.datasets):
+             assert len(ds1) == len(ds2)
+             ds1.features = ds1.features.join(ds2.features, **kwargs)
@@ -0,0 +1,43 @@
+ from collections.abc import Callable
+ from typing import Type
+
+ from .extractors import (
+     BivariateFeature,
+     DirectedBivariateFeature,
+     FeatureExtractor,
+     MultivariateFeature,
+     UnivariateFeature,
+     _get_underlying_func,
+ )
+
+
+ class FeaturePredecessor:
+     def __init__(self, *parent_extractor_type: Type):
+         parent_cls = parent_extractor_type
+         if not parent_cls:
+             parent_cls = [FeatureExtractor]
+         for p_cls in parent_cls:
+             assert issubclass(p_cls, FeatureExtractor)
+         self.parent_extractor_type = parent_cls
+
+     def __call__(self, func: Callable):
+         f = _get_underlying_func(func)
+         f.parent_extractor_type = self.parent_extractor_type
+         return func
+
+
+ class FeatureKind:
+     def __init__(self, feature_kind: MultivariateFeature):
+         self.feature_kind = feature_kind
+
+     def __call__(self, func):
+         f = _get_underlying_func(func)
+         f.feature_kind = self.feature_kind
+         return func
+
+
+ # Syntax sugar
+ univariate_feature = FeatureKind(UnivariateFeature())
+ bivariate_feature = FeatureKind(BivariateFeature())
+ directed_bivariate_feature = FeatureKind(DirectedBivariateFeature())
+ multivariate_feature = FeatureKind(MultivariateFeature())
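
A sketch of how these decorators are meant to tag a feature function (the function names and computations are illustrative, not from the package). FeatureKind attaches the naming scheme used for the outputs; FeaturePredecessor declares which extractor type the function expects as its parent:

    import numpy as np

    @univariate_feature
    def peak_to_peak(x):
        # x: (batch, channels, times) -> one value per channel and window
        return x.max(axis=-1) - x.min(axis=-1)

    @FeaturePredecessor(FeatureExtractor)
    @bivariate_feature
    def pairwise_product_mean(x):
        # illustrative channel-pair feature: one value per unordered channel pair
        i, j = np.triu_indices(x.shape[1], 1)
        return (x[:, i] * x[:, j]).mean(axis=-1)

Both decorators return the function unchanged and only set the parent_extractor_type / feature_kind attributes on it.
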
@@ -0,0 +1,210 @@
+ from abc import ABC, abstractmethod
+ from collections.abc import Callable
+ from functools import partial
+ from typing import Dict
+
+ import numpy as np
+ from numba.core.dispatcher import Dispatcher
+
+
+ def _get_underlying_func(func):
+     f = func
+     if isinstance(f, partial):
+         f = f.func
+     if isinstance(f, Dispatcher):
+         f = f.py_func
+     return f
+
+
+ class FitableFeature(ABC):
+     def __init__(self):
+         self._is_fitted = False
+         self.clear()
+
+     @abstractmethod
+     def clear(self):
+         pass
+
+     @abstractmethod
+     def partial_fit(self, *x, y=None):
+         pass
+
+     def fit(self):
+         self._is_fitted = True
+
+     def __call__(self, *args, **kwargs):
+         if not self._is_fitted:
+             raise RuntimeError(
+                 f"{self.__class__} cannot be called, it has to be fitted first."
+             )
+
+
+ class FeatureExtractor(FitableFeature):
+     def __init__(
+         self, feature_extractors: Dict[str, Callable], **preprocess_kwargs: Dict
+     ):
+         self.feature_extractors_dict = self._validate_execution_tree(feature_extractors)
+         self._is_fitable = self._check_is_fitable(feature_extractors)
+         super().__init__()
+
+         # bypassing FeaturePredecessor to avoid circular import
+         if not hasattr(self, "parent_extractor_type"):
+             self.parent_extractor_type = [FeatureExtractor]
+
+         self.preprocess_kwargs = preprocess_kwargs
+         if self.preprocess_kwargs is None:
+             self.preprocess_kwargs = dict()
+         self.features_kwargs = {
+             "preprocess_kwargs": preprocess_kwargs,
+         }
+         for fn, fe in feature_extractors.items():
+             if isinstance(fe, FeatureExtractor):
+                 self.features_kwargs[fn] = fe.features_kwargs
+             if isinstance(fe, partial):
+                 self.features_kwargs[fn] = fe.keywords
+
+     def _validate_execution_tree(self, feature_extractors):
+         for fname, f in feature_extractors.items():
+             f = _get_underlying_func(f)
+             pe_type = getattr(f, "parent_extractor_type", [FeatureExtractor])
+             assert type(self) in pe_type
+         return feature_extractors
+
+     def _check_is_fitable(self, feature_extractors):
+         is_fitable = False
+         for fname, f in feature_extractors.items():
+             if isinstance(f, FeatureExtractor):
+                 is_fitable = f._is_fitable
+             else:
+                 f = _get_underlying_func(f)
+                 if isinstance(f, FitableFeature):
+                     is_fitable = True
+             if is_fitable:
+                 break
+         return is_fitable
+
+     def preprocess(self, *x, **kwargs):
+         return (*x,)
+
+     def feature_channel_names(self, ch_names):
+         return [""]
+
+     def __call__(self, *x, _batch_size=None, _ch_names=None):
+         assert _batch_size is not None
+         assert _ch_names is not None
+         if self._is_fitable:
+             super().__call__()
+         results_dict = dict()
+         z = self.preprocess(*x, **self.preprocess_kwargs)
+         for fname, f in self.feature_extractors_dict.items():
+             if isinstance(f, FeatureExtractor):
+                 r = f(*z, _batch_size=_batch_size, _ch_names=_ch_names)
+             else:
+                 r = f(*z)
+                 f = _get_underlying_func(f)
+                 if hasattr(f, "feature_kind"):
+                     r = f.feature_kind(r, _ch_names=_ch_names)
+             if not isinstance(fname, str) or not fname:
+                 if isinstance(f, FeatureExtractor) or not hasattr(f, "__name__"):
+                     fname = ""
+                 else:
+                     fname = f.__name__
+             if isinstance(r, dict):
+                 if fname:
+                     fname += "_"
+                 for k, v in r.items():
+                     self._add_feature_to_dict(results_dict, fname + k, v, _batch_size)
+             else:
+                 self._add_feature_to_dict(results_dict, fname, r, _batch_size)
+         return results_dict
+
+     def _add_feature_to_dict(self, results_dict, name, value, batch_size):
+         if not isinstance(value, np.ndarray):
+             results_dict[name] = value
+         else:
+             assert value.shape[0] == batch_size
+             results_dict[name] = value
+
+     def clear(self):
+         if not self._is_fitable:
+             return
+         for fname, f in self.feature_extractors_dict.items():
+             f = _get_underlying_func(f)
+             if isinstance(f, FitableFeature):
+                 f.clear()
+
+     def partial_fit(self, *x, y=None):
+         if not self._is_fitable:
+             return
+         z = self.preprocess(*x, **self.preprocess_kwargs)
+         for fname, f in self.feature_extractors_dict.items():
+             f = _get_underlying_func(f)
+             if isinstance(f, FitableFeature):
+                 f.partial_fit(*z, y=y)
+
+     def fit(self):
+         if not self._is_fitable:
+             return
+         for fname, f in self.feature_extractors_dict.items():
+             f = _get_underlying_func(f)
+             if isinstance(f, FitableFeature):
+                 f.fit()
+         super().fit()
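
A minimal sketch of using the extractor above with a plain feature function; univariate_feature comes from the decorators module added in this diff, and the window shapes and channel names are illustrative:

    import numpy as np

    @univariate_feature
    def mean_amplitude(x):
        return x.mean(axis=-1)  # (batch, channels, times) -> (batch, channels)

    fe = FeatureExtractor({"mean": mean_amplitude})
    windows = np.random.randn(8, 3, 256)  # 8 windows, 3 channels, 256 samples each
    feats = fe(windows, _batch_size=8, _ch_names=["C3", "Cz", "C4"])
    # feats maps "mean_C3", "mean_Cz", "mean_C4" to arrays of shape (8,)
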
+
+
+ class MultivariateFeature:
+     def __call__(self, x, _ch_names=None):
+         assert _ch_names is not None
+         f_channels = self.feature_channel_names(_ch_names)
+         if isinstance(x, dict):
+             r = dict()
+             for k, v in x.items():
+                 r.update(self._array_to_dict(v, f_channels, k))
+             return r
+         return self._array_to_dict(x, f_channels)
+
+     @staticmethod
+     def _array_to_dict(x, f_channels, name=""):
+         assert isinstance(x, np.ndarray)
+         if len(f_channels) == 0:
+             assert x.ndim == 1
+             if name:
+                 return {name: x}
+             return x
+         assert x.shape[1] == len(f_channels)
+         x = x.swapaxes(0, 1)
+         names = [f"{name}_{ch}" for ch in f_channels] if name else f_channels
+         return dict(zip(names, x))
+
+     def feature_channel_names(self, ch_names):
+         return []
+
+
+ class UnivariateFeature(MultivariateFeature):
+     def feature_channel_names(self, ch_names):
+         return ch_names
+
+
+ class BivariateFeature(MultivariateFeature):
+     def __init__(self, *args, channel_pair_format="{}<>{}"):
+         super().__init__(*args)
+         self.channel_pair_format = channel_pair_format
+
+     @staticmethod
+     def get_pair_iterators(n):
+         return np.triu_indices(n, 1)
+
+     def feature_channel_names(self, ch_names):
+         return [
+             self.channel_pair_format.format(ch_names[i], ch_names[j])
+             for i, j in zip(*self.get_pair_iterators(len(ch_names)))
+         ]
+
+
+ class DirectedBivariateFeature(BivariateFeature):
+     @staticmethod
+     def get_pair_iterators(n):
+         return [
+             np.append(a, b)
+             for a, b in zip(np.tril_indices(n, -1), np.triu_indices(n, 1))
+         ]
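
A worked example of the channel-pair naming for three channels: BivariateFeature enumerates each unordered pair once (upper triangle), while DirectedBivariateFeature lists both directions:

    ch_names = ["C3", "Cz", "C4"]
    BivariateFeature().feature_channel_names(ch_names)
    # ['C3<>Cz', 'C3<>C4', 'Cz<>C4']
    DirectedBivariateFeature().feature_channel_names(ch_names)
    # ['Cz<>C3', 'C4<>C3', 'C4<>Cz', 'C3<>Cz', 'C3<>C4', 'Cz<>C4']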