eegdash 0.0.8__py3-none-any.whl → 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of eegdash has been flagged as potentially problematic.
- eegdash/__init__.py +4 -1
- eegdash/data_config.py +28 -0
- eegdash/data_utils.py +193 -148
- eegdash/features/__init__.py +25 -0
- eegdash/features/datasets.py +456 -0
- eegdash/features/decorators.py +43 -0
- eegdash/features/extractors.py +210 -0
- eegdash/features/feature_bank/__init__.py +6 -0
- eegdash/features/feature_bank/complexity.py +96 -0
- eegdash/features/feature_bank/connectivity.py +59 -0
- eegdash/features/feature_bank/csp.py +101 -0
- eegdash/features/feature_bank/dimensionality.py +107 -0
- eegdash/features/feature_bank/signal.py +103 -0
- eegdash/features/feature_bank/spectral.py +116 -0
- eegdash/features/feature_bank/utils.py +48 -0
- eegdash/features/serialization.py +87 -0
- eegdash/features/utils.py +116 -0
- eegdash/main.py +250 -145
- {eegdash-0.0.8.dist-info → eegdash-0.1.0.dist-info}/METADATA +26 -56
- eegdash-0.1.0.dist-info/RECORD +23 -0
- {eegdash-0.0.8.dist-info → eegdash-0.1.0.dist-info}/WHEEL +1 -1
- eegdash-0.0.8.dist-info/RECORD +0 -8
- {eegdash-0.0.8.dist-info → eegdash-0.1.0.dist-info}/licenses/LICENSE +0 -0
- {eegdash-0.0.8.dist-info → eegdash-0.1.0.dist-info}/top_level.txt +0 -0
eegdash/features/datasets.py
@@ -0,0 +1,456 @@
+from __future__ import annotations
+
+import json
+import os
+import shutil
+import warnings
+from collections.abc import Callable, Iterable
+from typing import Dict, no_type_check
+
+import numpy as np
+import pandas as pd
+from joblib import Parallel, delayed
+
+from braindecode.datasets.base import (
+    BaseConcatDataset,
+    EEGWindowsDataset,
+    _create_description,
+)
+
+
+class FeaturesDataset(EEGWindowsDataset):
+    """Returns samples from a pandas DataFrame object along with a target.
+
+    Dataset which serves samples from a pandas DataFrame object along with a
+    target. The target is unique for the dataset, and is obtained through the
+    `description` attribute.
+
+    Parameters
+    ----------
+    features : a pandas DataFrame
+        Tabular data.
+    description : dict | pandas.Series | None
+        Holds additional description about the continuous signal / subject.
+    transform : callable | None
+        On-the-fly transform applied to the example before it is returned.
+    """
+
+    def __init__(
+        self,
+        features: pd.DataFrame,
+        metadata: pd.DataFrame | None = None,
+        description: dict | pd.Series | None = None,
+        transform: Callable | None = None,
+        raw_info: Dict | None = None,
+        raw_preproc_kwargs: Dict | None = None,
+        window_kwargs: Dict | None = None,
+        window_preproc_kwargs: Dict | None = None,
+        features_kwargs: Dict | None = None,
+    ):
+        self.features = features
+        self.n_features = features.columns.size
+        self.metadata = metadata
+        self._description = _create_description(description)
+        self.transform = transform
+        self.raw_info = raw_info
+        self.raw_preproc_kwargs = raw_preproc_kwargs
+        self.window_kwargs = window_kwargs
+        self.window_preproc_kwargs = window_preproc_kwargs
+        self.features_kwargs = features_kwargs
+
+        self.crop_inds = metadata.loc[
+            :, ["i_window_in_trial", "i_start_in_trial", "i_stop_in_trial"]
+        ].to_numpy()
+        self.y = metadata.loc[:, "target"].to_list()
+
+    def __getitem__(self, index):
+        crop_inds = self.crop_inds[index].tolist()
+        X = self.features.iloc[index].to_numpy()
+        X = X.copy()
+        X.astype("float32")
+        if self.transform is not None:
+            X = self.transform(X)
+        y = self.y[index]
+        return X, y, crop_inds
+
+    def __len__(self):
+        return len(self.features.index)
+
+
+def _compute_stats(
+    ds: FeaturesDataset,
+    return_count=False,
+    return_mean=False,
+    return_var=False,
+    ddof=1,
+    numeric_only=False,
+):
+    res = []
+    if return_count:
+        res.append(ds.features.count(numeric_only=numeric_only))
+    if return_mean:
+        res.append(ds.features.mean(numeric_only=numeric_only))
+    if return_var:
+        res.append(ds.features.var(ddof=ddof, numeric_only=numeric_only))
+    return tuple(res)
+
+
+def _pooled_var(counts, means, variances, ddof):
+    count = counts.sum(axis=0)
+    mean = np.sum((counts / count) * means, axis=0)
+    var = np.sum(((counts - ddof) / (count - ddof)) * variances, axis=0)
+    var[:] += np.sum((counts / (count - ddof)) * (means**2), axis=0)
+    var[:] -= (count / (count - ddof)) * (mean**2)
+    var[:] = var.clip(min=0)
+    return count, mean, var
+
+
+class FeaturesConcatDataset(BaseConcatDataset):
+    """A base class for concatenated datasets.
+
+    Holds either mne.Raw or mne.Epoch in self.datasets and has
+    a pandas DataFrame with additional description.
+
+    Parameters
+    ----------
+    list_of_ds : list
+        list of BaseDataset, BaseConcatDataset or WindowsDataset
+    target_transform : callable | None
+        Optional function to call on targets before returning them.
+
+    """
+
+    def __init__(
+        self,
+        list_of_ds: list[FeaturesDataset] | None = None,
+        target_transform: Callable | None = None,
+    ):
+        # if we get a list of FeaturesConcatDataset, get all the individual datasets
+        if list_of_ds and isinstance(list_of_ds[0], FeaturesConcatDataset):
+            list_of_ds = [d for ds in list_of_ds for d in ds.datasets]
+        super().__init__(list_of_ds)
+
+        self.target_transform = target_transform
+
+    def split(
+        self,
+        by: str | list[int] | list[list[int]] | dict[str, list[int]],
+    ) -> dict[str, FeaturesConcatDataset]:
+        """Split the dataset based on information listed in its description.
+
+        The format could be based on a DataFrame or based on indices.
+
+        Parameters
+        ----------
+        by : str | list | dict
+            If ``by`` is a string, splitting is performed based on the
+            description DataFrame column with this name.
+            If ``by`` is a (list of) list of integers, the position in the first
+            list corresponds to the split id and the integers to the
+            datapoints of that split.
+            If a dict then each key will be used in the returned
+            splits dict and each value should be a list of int.
+
+        Returns
+        -------
+        splits : dict
+            A dictionary with the name of the split (a string) as key and the
+            dataset as value.
+        """
+        if isinstance(by, str):
+            split_ids = {
+                k: list(v) for k, v in self.description.groupby(by).groups.items()
+            }
+        elif isinstance(by, dict):
+            split_ids = by
+        else:
+            # assume list(int)
+            if not isinstance(by[0], list):
+                by = [by]
+            # assume list(list(int))
+            split_ids = {split_i: split for split_i, split in enumerate(by)}
+
+        return {
+            str(split_name): FeaturesConcatDataset(
+                [self.datasets[ds_ind] for ds_ind in ds_inds],
+                target_transform=self.target_transform,
+            )
+            for split_name, ds_inds in split_ids.items()
+        }
+
+    def get_metadata(self) -> pd.DataFrame:
+        """Concatenate the metadata and description of the wrapped Epochs.
+
+        Returns
+        -------
+        metadata : pd.DataFrame
+            DataFrame containing as many rows as there are windows in the
+            BaseConcatDataset, with the metadata and description information
+            for each window.
+        """
+        if not all([isinstance(ds, FeaturesDataset) for ds in self.datasets]):
+            raise TypeError(
+                "Metadata dataframe can only be computed when all "
+                "datasets are FeaturesDataset."
+            )
+
+        all_dfs = list()
+        for ds in self.datasets:
+            df = ds.metadata
+            for k, v in ds.description.items():
+                df[k] = v
+            all_dfs.append(df)
+
+        return pd.concat(all_dfs)
+
+    def save(self, path: str, overwrite: bool = False, offset: int = 0):
+        """Save datasets to files by creating one subdirectory for each dataset:
+        path/
+            0/
+                0-feat.parquet
+                metadata_df.pkl
+                description.json
+                raw-info.fif (if raw info was saved)
+                raw_preproc_kwargs.json (if raws were preprocessed)
+                window_kwargs.json (if this is a windowed dataset)
+                window_preproc_kwargs.json (if windows were preprocessed)
+                features_kwargs.json
+            1/
+                1-feat.parquet
+                metadata_df.pkl
+                description.json
+                raw-info.fif (if raw info was saved)
+                raw_preproc_kwargs.json (if raws were preprocessed)
+                window_kwargs.json (if this is a windowed dataset)
+                window_preproc_kwargs.json (if windows were preprocessed)
+                features_kwargs.json
+
+        Parameters
+        ----------
+        path : str
+            Directory in which subdirectories are created to store
+            -feat.parquet and .json files to.
+        overwrite : bool
+            Whether to delete old subdirectories that will be saved to in this
+            call.
+        offset : int
+            If provided, the integer is added to the id of the dataset in the
+            concat. This is useful in the setting of very large datasets, where
+            one dataset has to be processed and saved at a time to account for
+            its original position.
+        """
+        if len(self.datasets) == 0:
+            raise ValueError("Expect at least one dataset")
+        path_contents = os.listdir(path)
+        n_sub_dirs = len([os.path.isdir(e) for e in path_contents])
+        for i_ds, ds in enumerate(self.datasets):
+            # remove subdirectory from list of untouched files / subdirectories
+            if str(i_ds + offset) in path_contents:
+                path_contents.remove(str(i_ds + offset))
+            # save_dir/i_ds/
+            sub_dir = os.path.join(path, str(i_ds + offset))
+            if os.path.exists(sub_dir):
+                if overwrite:
+                    shutil.rmtree(sub_dir)
+                else:
+                    raise FileExistsError(
+                        f"Subdirectory {sub_dir} already exists. Please select"
+                        f" a different directory, set overwrite=True, or "
+                        f"resolve manually."
+                    )
+            # save_dir/{i_ds+offset}/
+            os.makedirs(sub_dir)
+            # save_dir/{i_ds+offset}/{i_ds+offset}-feat.parquet
+            self._save_features(sub_dir, ds, i_ds, offset)
+            # save_dir/{i_ds+offset}/metadata_df.pkl
+            self._save_metadata(sub_dir, ds)
+            # save_dir/{i_ds+offset}/description.json
+            self._save_description(sub_dir, ds.description)
+            # save_dir/{i_ds+offset}/raw-info.fif
+            self._save_raw_info(sub_dir, ds)
+            # save_dir/{i_ds+offset}/raw_preproc_kwargs.json
+            # save_dir/{i_ds+offset}/window_kwargs.json
+            # save_dir/{i_ds+offset}/window_preproc_kwargs.json
+            # save_dir/{i_ds+offset}/features_kwargs.json
+            self._save_kwargs(sub_dir, ds)
+            if overwrite:
+                # the following will be True for all datasets preprocessed and
+                # stored in parallel with braindecode.preprocessing.preprocess
+                if i_ds + 1 + offset < n_sub_dirs:
+                    warnings.warn(
+                        f"The number of saved datasets ({i_ds + 1 + offset}) "
+                        f"does not match the number of existing "
+                        f"subdirectories ({n_sub_dirs}). You may now "
+                        f"encounter a mix of differently preprocessed "
+                        f"datasets!",
+                        UserWarning,
+                    )
+        # if path contains files or directories that were not touched, raise
+        # warning
+        if path_contents:
+            warnings.warn(
+                f"Chosen directory {path} contains other "
+                f"subdirectories or files {path_contents}."
+            )
+
+    @staticmethod
+    def _save_features(sub_dir, ds, i_ds, offset):
+        parquet_file_name = f"{i_ds + offset}-feat.parquet"
+        parquet_file_path = os.path.join(sub_dir, parquet_file_name)
+        ds.features.to_parquet(parquet_file_path)
+
+    @staticmethod
+    def _save_raw_info(sub_dir, ds):
+        if hasattr(ds, "raw_info"):
+            fif_file_name = "raw-info.fif"
+            fif_file_path = os.path.join(sub_dir, fif_file_name)
+            ds.raw_info.save(fif_file_path)
+
+    @staticmethod
+    def _save_kwargs(sub_dir, ds):
+        for kwargs_name in [
+            "raw_preproc_kwargs",
+            "window_kwargs",
+            "window_preproc_kwargs",
+            "features_kwargs",
+        ]:
+            if hasattr(ds, kwargs_name):
+                kwargs_file_name = ".".join([kwargs_name, "json"])
+                kwargs_file_path = os.path.join(sub_dir, kwargs_file_name)
+                kwargs = getattr(ds, kwargs_name)
+                if kwargs is not None:
+                    with open(kwargs_file_path, "w") as f:
+                        json.dump(kwargs, f)
+
+    def to_dataframe(
+        self, include_metadata=False, include_target=False, include_crop_inds=False
+    ):
+        if include_metadata or (include_target and include_crop_inds):
+            dataframes = [
+                ds.metadata.join(ds.features, how="right", lsuffix="_metadata")
+                for ds in self.datasets
+            ]
+        elif include_target:
+            dataframes = [
+                ds.features.join(ds.metadata["target"], how="left", rsuffix="_metadata")
+                for ds in self.datasets
+            ]
+        elif include_crop_inds:
+            dataframes = [
+                ds.metadata.drop("target", axis="columns").join(
+                    ds.features, how="right", lsuffix="_metadata"
+                )
+                for ds in self.datasets
+            ]
+        else:
+            dataframes = [ds.features for ds in self.datasets]
+        return pd.concat(dataframes, axis=0, ignore_index=True)
+
+    def _numeric_columns(self):
+        return self.datasets[0].features.select_dtypes(include=np.number).columns
+
+    def count(self, numeric_only=False, n_jobs=1):
+        stats = Parallel(n_jobs)(
+            delayed(_compute_stats)(ds, return_count=True, numeric_only=numeric_only)
+            for ds in self.datasets
+        )
+        counts = np.array([s[0] for s in stats])
+        count = counts.sum(axis=0)
+        return pd.Series(count, index=self._numeric_columns())
+
+    def mean(self, numeric_only=False, n_jobs=1):
+        stats = Parallel(n_jobs)(
+            delayed(_compute_stats)(
+                ds, return_count=True, return_mean=True, numeric_only=numeric_only
+            )
+            for ds in self.datasets
+        )
+        counts, means = np.array([s[0] for s in stats]), np.array([s[1] for s in stats])
+        count = counts.sum(axis=0, keepdims=True)
+        mean = np.sum((counts / count) * means, axis=0)
+        return pd.Series(mean, index=self._numeric_columns())
+
+    def var(self, ddof=1, numeric_only=False, n_jobs=1):
+        stats = Parallel(n_jobs)(
+            delayed(_compute_stats)(
+                ds,
+                return_count=True,
+                return_mean=True,
+                return_var=True,
+                ddof=ddof,
+                numeric_only=numeric_only,
+            )
+            for ds in self.datasets
+        )
+        counts, means, variances = (
+            np.array([s[0] for s in stats]),
+            np.array([s[1] for s in stats]),
+            np.array([s[2] for s in stats]),
+        )
+        _, _, var = _pooled_var(counts, means, variances, ddof)
+        return pd.Series(var, index=self._numeric_columns())
+
+    def std(self, ddof=1, numeric_only=False, n_jobs=1):
+        return np.sqrt(self.var(ddof=ddof, numeric_only=numeric_only, n_jobs=n_jobs))
+
+    def zscore(self, ddof=1, numeric_only=False, eps=0, n_jobs=1):
+        stats = Parallel(n_jobs)(
+            delayed(_compute_stats)(
+                ds,
+                return_count=True,
+                return_mean=True,
+                return_var=True,
+                ddof=ddof,
+                numeric_only=numeric_only,
+            )
+            for ds in self.datasets
+        )
+        counts, means, variances = (
+            np.array([s[0] for s in stats]),
+            np.array([s[1] for s in stats]),
+            np.array([s[2] for s in stats]),
+        )
+        _, mean, var = _pooled_var(counts, means, variances, ddof)
+        std = np.sqrt(var) + eps
+        for ds in self.datasets:
+            ds.features = (ds.features - mean) / std
+
+    @staticmethod
+    def _enforce_inplace_operations(func_name, kwargs):
+        if "inplace" in kwargs and kwargs["inplace"] is False:
+            raise ValueError(
+                f"{func_name} only works inplace, please change "
+                + "to inplace=True (default)."
+            )
+        kwargs["inplace"] = True
+
+    def fillna(self, *args, **kwargs):
+        FeaturesConcatDataset._enforce_inplace_operations("fillna", kwargs)
+        for ds in self.datasets:
+            ds.features.fillna(*args, **kwargs)
+
+    def replace(self, *args, **kwargs):
+        FeaturesConcatDataset._enforce_inplace_operations("replace", kwargs)
+        for ds in self.datasets:
+            ds.features.replace(*args, **kwargs)
+
+    def interpolate(self, *args, **kwargs):
+        FeaturesConcatDataset._enforce_inplace_operations("interpolate", kwargs)
+        for ds in self.datasets:
+            ds.features.interpolate(*args, **kwargs)
+
+    def dropna(self, *args, **kwargs):
+        FeaturesConcatDataset._enforce_inplace_operations("dropna", kwargs)
+        for ds in self.datasets:
+            ds.features.dropna(*args, **kwargs)
+
+    def drop(self, *args, **kwargs):
+        FeaturesConcatDataset._enforce_inplace_operations("drop", kwargs)
+        for ds in self.datasets:
+            ds.features.drop(*args, **kwargs)
+
+    def join(self, concat_dataset: FeaturesConcatDataset, **kwargs):
+        assert len(self.datasets) == len(concat_dataset.datasets)
+        for ds1, ds2 in zip(self.datasets, concat_dataset.datasets):
+            assert len(ds1) == len(ds2)
+            ds1.features.join(ds2, **kwargs)
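
For orientation, the snippet below is a minimal usage sketch of the FeaturesDataset / FeaturesConcatDataset API added above; it is not part of the diff. The toy feature values, metadata columns, and subject labels are invented for illustration, and the import path follows the new eegdash/features/datasets.py module listed in the file summary.

import pandas as pd

from eegdash.features.datasets import FeaturesConcatDataset, FeaturesDataset


def make_toy_dataset(subject):
    # Two windows with two scalar features each (values are made up).
    features = pd.DataFrame({"alpha_power": [1.0, 2.0], "beta_power": [0.5, 1.5]})
    # Metadata columns expected by FeaturesDataset.__init__.
    metadata = pd.DataFrame(
        {
            "i_window_in_trial": [0, 1],
            "i_start_in_trial": [0, 128],
            "i_stop_in_trial": [128, 256],
            "target": [0, 1],
        }
    )
    return FeaturesDataset(features, metadata=metadata, description={"subject": subject})


concat = FeaturesConcatDataset([make_toy_dataset("s01"), make_toy_dataset("s02")])
X, y, crop_inds = concat[0]                   # feature vector, target, crop indices
pooled_mean = concat.mean(numeric_only=True)  # pooled per-feature mean across datasets
concat.zscore(numeric_only=True)              # in-place standardization of every feature frame
splits = concat.split("subject")              # {"s01": FeaturesConcatDataset, "s02": ...}

The pooled statistics (count, mean, var, std, zscore) are computed per dataset with _compute_stats and combined with _pooled_var, so they never require loading all feature frames into one DataFrame.
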
eegdash/features/decorators.py
@@ -0,0 +1,43 @@
+from collections.abc import Callable
+from typing import List, Type
+
+from .extractors import (
+    BivariateFeature,
+    DirectedBivariateFeature,
+    FeatureExtractor,
+    MultivariateFeature,
+    UnivariateFeature,
+    _get_underlying_func,
+)
+
+
+class FeaturePredecessor:
+    def __init__(self, *parent_extractor_type: List[Type]):
+        parent_cls = parent_extractor_type
+        if not parent_cls:
+            parent_cls = [FeatureExtractor]
+        for p_cls in parent_cls:
+            assert issubclass(p_cls, FeatureExtractor)
+        self.parent_extractor_type = parent_cls
+
+    def __call__(self, func: Callable):
+        f = _get_underlying_func(func)
+        f.parent_extractor_type = self.parent_extractor_type
+        return func
+
+
+class FeatureKind:
+    def __init__(self, feature_kind: MultivariateFeature):
+        self.feature_kind = feature_kind
+
+    def __call__(self, func):
+        f = _get_underlying_func(func)
+        f.feature_kind = self.feature_kind
+        return func
+
+
+# Syntax sugar
+univariate_feature = FeatureKind(UnivariateFeature())
+bivariate_feature = FeatureKind(BivariateFeature())
+directed_bivariate_feature = FeatureKind(DirectedBivariateFeature())
+multivariate_feature = FeatureKind(MultivariateFeature())
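
A short sketch of how these decorators are meant to be applied to a feature function (the function below is hypothetical and not part of the diff):

import numpy as np

from eegdash.features.decorators import FeaturePredecessor, univariate_feature
from eegdash.features.extractors import FeatureExtractor


# Hypothetical per-channel feature computed over (batch, channel, time) arrays.
@FeaturePredecessor(FeatureExtractor)  # may only be nested under a plain FeatureExtractor
@univariate_feature                    # tags the output as one value per channel
def peak_to_peak(x):
    return np.ptp(x, axis=-1)

The decorators only attach attributes (parent_extractor_type, feature_kind) to the underlying function; these are the attributes that FeatureExtractor inspects in extractors.py below.
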
eegdash/features/extractors.py
@@ -0,0 +1,210 @@
+from abc import ABC, abstractmethod
+from collections.abc import Callable
+from functools import partial
+from typing import Dict
+
+import numpy as np
+from numba.core.dispatcher import Dispatcher
+
+
+def _get_underlying_func(func):
+    f = func
+    if isinstance(f, partial):
+        f = f.func
+    if isinstance(f, Dispatcher):
+        f = f.py_func
+    return f
+
+
+class FitableFeature(ABC):
+    def __init__(self):
+        self._is_fitted = False
+        self.clear()
+
+    @abstractmethod
+    def clear(self):
+        pass
+
+    @abstractmethod
+    def partial_fit(self, *x, y=None):
+        pass
+
+    def fit(self):
+        self._is_fitted = True
+
+    def __call__(self, *args, **kwargs):
+        if not self._is_fitted:
+            raise RuntimeError(
+                f"{self.__class__} cannot be called, it has to be fitted first."
+            )
+
+
+class FeatureExtractor(FitableFeature):
+    def __init__(
+        self, feature_extractors: Dict[str, Callable], **preprocess_kwargs: Dict
+    ):
+        self.feature_extractors_dict = self._validate_execution_tree(feature_extractors)
+        self._is_fitable = self._check_is_fitable(feature_extractors)
+        super().__init__()
+
+        # bypassing FeaturePredecessor to avoid circular import
+        if not hasattr(self, "parent_extractor_type"):
+            self.parent_extractor_type = [FeatureExtractor]
+
+        self.preprocess_kwargs = preprocess_kwargs
+        if self.preprocess_kwargs is None:
+            self.preprocess_kwargs = dict()
+        self.features_kwargs = {
+            "preprocess_kwargs": preprocess_kwargs,
+        }
+        for fn, fe in feature_extractors.items():
+            if isinstance(fe, FeatureExtractor):
+                self.features_kwargs[fn] = fe.features_kwargs
+            if isinstance(fe, partial):
+                self.features_kwargs[fn] = fe.keywords
+
+    def _validate_execution_tree(self, feature_extractors):
+        for fname, f in feature_extractors.items():
+            f = _get_underlying_func(f)
+            pe_type = getattr(f, "parent_extractor_type", [FeatureExtractor])
+            assert type(self) in pe_type
+        return feature_extractors
+
+    def _check_is_fitable(self, feature_extractors):
+        is_fitable = False
+        for fname, f in feature_extractors.items():
+            if isinstance(f, FeatureExtractor):
+                is_fitable = f._is_fitable
+            else:
+                f = _get_underlying_func(f)
+                if isinstance(f, FitableFeature):
+                    is_fitable = True
+            if is_fitable:
+                break
+        return is_fitable
+
+    def preprocess(self, *x, **kwargs):
+        return (*x,)
+
+    def feature_channel_names(self, ch_names):
+        return [""]
+
+    def __call__(self, *x, _batch_size=None, _ch_names=None):
+        assert _batch_size is not None
+        assert _ch_names is not None
+        if self._is_fitable:
+            super().__call__()
+        results_dict = dict()
+        z = self.preprocess(*x, **self.preprocess_kwargs)
+        for fname, f in self.feature_extractors_dict.items():
+            if isinstance(f, FeatureExtractor):
+                r = f(*z, _batch_size=_batch_size, _ch_names=_ch_names)
+            else:
+                r = f(*z)
+            f = _get_underlying_func(f)
+            if hasattr(f, "feature_kind"):
+                r = f.feature_kind(r, _ch_names=_ch_names)
+            if not isinstance(fname, str) or not fname:
+                if isinstance(f, FeatureExtractor) or not hasattr(f, "__name__"):
+                    fname = ""
+                else:
+                    fname = f.__name__
+            if isinstance(r, dict):
+                if fname:
+                    fname += "_"
+                for k, v in r.items():
+                    self._add_feature_to_dict(results_dict, fname + k, v, _batch_size)
+            else:
+                self._add_feature_to_dict(results_dict, fname, r, _batch_size)
+        return results_dict
+
+    def _add_feature_to_dict(self, results_dict, name, value, batch_size):
+        if not isinstance(value, np.ndarray):
+            results_dict[name] = value
+        else:
+            assert value.shape[0] == batch_size
+            results_dict[name] = value
+
+    def clear(self):
+        if not self._is_fitable:
+            return
+        for fname, f in self.feature_extractors_dict.items():
+            f = _get_underlying_func(f)
+            if isinstance(f, FitableFeature):
+                f.clear()
+
+    def partial_fit(self, *x, y=None):
+        if not self._is_fitable:
+            return
+        z = self.preprocess(*x, **self.preprocess_kwargs)
+        for fname, f in self.feature_extractors_dict.items():
+            f = _get_underlying_func(f)
+            if isinstance(f, FitableFeature):
+                f.partial_fit(*z, y=y)
+
+    def fit(self):
+        if not self._is_fitable:
+            return
+        for fname, f in self.feature_extractors_dict.items():
+            f = _get_underlying_func(f)
+            if isinstance(f, FitableFeature):
+                f.fit()
+        super().fit()
+
+
+class MultivariateFeature:
+    def __call__(self, x, _ch_names=None):
+        assert _ch_names is not None
+        f_channels = self.feature_channel_names(_ch_names)
+        if isinstance(x, dict):
+            r = dict()
+            for k, v in x.items():
+                r.update(self._array_to_dict(v, f_channels, k))
+            return r
+        return self._array_to_dict(x, f_channels)
+
+    @staticmethod
+    def _array_to_dict(x, f_channels, name=""):
+        assert isinstance(x, np.ndarray)
+        if len(f_channels) == 0:
+            assert x.ndim == 1
+            if name:
+                return {name: x}
+            return x
+        assert x.shape[1] == len(f_channels)
+        x = x.swapaxes(0, 1)
+        names = [f"{name}_{ch}" for ch in f_channels] if name else f_channels
+        return dict(zip(names, x))
+
+    def feature_channel_names(self, ch_names):
+        return []
+
+
+class UnivariateFeature(MultivariateFeature):
+    def feature_channel_names(self, ch_names):
+        return ch_names
+
+
+class BivariateFeature(MultivariateFeature):
+    def __init__(self, *args, channel_pair_format="{}<>{}"):
+        super().__init__(*args)
+        self.channel_pair_format = channel_pair_format
+
+    @staticmethod
+    def get_pair_iterators(n):
+        return np.triu_indices(n, 1)
+
+    def feature_channel_names(self, ch_names):
+        return [
+            self.channel_pair_format.format(ch_names[i], ch_names[j])
+            for i, j in zip(*self.get_pair_iterators(len(ch_names)))
+        ]
+
+
+class DirectedBivariateFeature(BivariateFeature):
+    @staticmethod
+    def get_pair_iterators(n):
+        return [
+            np.append(a, b)
+            for a, b in zip(np.tril_indices(n, -1), np.triu_indices(n, 1))
+        ]