eegdash 0.0.8__py3-none-any.whl → 0.0.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


eegdash/data_config.py ADDED
@@ -0,0 +1,28 @@
+ config = {
+     "required_fields": ["data_name"],
+     "attributes": {
+         "data_name": "str",
+         "dataset": "str",
+         "bidspath": "str",
+         "subject": "str",
+         "task": "str",
+         "session": "str",
+         "run": "str",
+         "sampling_frequency": "float",
+         "modality": "str",
+         "nchans": "int",
+         "ntimes": "int"
+     },
+     "description_fields": ["subject", "session", "run", "task", "age", "gender", "sex"],
+     "bids_dependencies_files": [
+         "dataset_description.json",
+         "participants.tsv",
+         "events.tsv",
+         "events.json",
+         "eeg.json",
+         "electrodes.tsv",
+         "channels.tsv",
+         "coordsystem.json"
+     ],
+     "accepted_query_fields": ["data_name", "dataset"]
+ }
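The new module centralizes the record schema: required and typed attributes, the fields copied into a dataset description, the BIDS sidecar files a record depends on, and the fields accepted in queries. A minimal sketch (not part of the package) of how the config could be used to reject unsupported query fields; the filter_query helper and the example dataset ID are illustrative only:

    from eegdash.data_config import config

    def filter_query(query: dict) -> dict:
        # Hypothetical helper: keep only fields the database accepts.
        unknown = set(query) - set(config["accepted_query_fields"])
        if unknown:
            raise ValueError(f"Unsupported query fields: {sorted(unknown)}")
        return query

    print(filter_query({"dataset": "ds002718"}))  # {'dataset': 'ds002718'}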
eegdash/data_utils.py CHANGED
@@ -17,6 +17,7 @@ import mne_bids
  from mne_bids import (
      BIDSPath,
  )
+ from bids import BIDSLayout

  class EEGDashBaseDataset(BaseDataset):
      """Returns samples from an mne.io.Raw object along with a target.
@@ -96,7 +97,7 @@ class EEGDashBaseDataset(BaseDataset):

      def __len__(self):
          if self._raw is None:
-             return self.record['rawdatainfo']['ntimes']
+             return int(self.record['ntimes'] * self.record['sampling_frequency'])
          else:
              return len(self._raw)

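In the new record schema, ntimes stores the recording duration in seconds (it is populated from the BIDS RecordingDuration field, see the attribute mapping in get_bids_file_attribute below), so the sample count is recovered by multiplying it by the sampling frequency. A quick sketch of the arithmetic with made-up values:

    record = {"ntimes": 120.0, "sampling_frequency": 256.0}  # hypothetical record
    n_samples = int(record["ntimes"] * record["sampling_frequency"])
    assert n_samples == 30720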
@@ -216,39 +217,49 @@ class EEGDashBaseRaw(BaseRaw):
      _read_segments_file(self, data, idx, fi, start, stop, cals, mult, dtype="<f4")


- class BIDSDataset():
+ class EEGBIDSDataset():
      ALLOWED_FILE_FORMAT = ['eeglab', 'brainvision', 'biosemi', 'european']
-     RAW_EXTENSION = {
-         'eeglab': '.set',
-         'brainvision': '.vhdr',
-         'biosemi': '.bdf',
-         'european': '.edf'
-     }
+     RAW_EXTENSIONS = {
+         '.set': ['.set', '.fdt'],                              # eeglab
+         '.edf': ['.edf'],                                      # european
+         '.vhdr': ['.eeg', '.vhdr', '.vmrk', '.dat', '.raw'],   # brainvision
+         '.bdf': ['.bdf'],                                      # biosemi
+     }
      METADATA_FILE_EXTENSIONS = ['eeg.json', 'channels.tsv', 'electrodes.tsv', 'events.tsv', 'events.json']
      def __init__(self,
                   data_dir=None,  # location of bids dataset
                   dataset='',     # dataset name
-                  raw_format='eeglab',  # format of raw data
                   ):
          if data_dir is None or not os.path.exists(data_dir):
              raise ValueError('data_dir must be specified and must exist')
          self.bidsdir = Path(data_dir)
          self.dataset = dataset
          assert str(self.bidsdir).endswith(self.dataset)
-
-         if raw_format.lower() not in self.ALLOWED_FILE_FORMAT:
-             raise ValueError('raw_format must be one of {}'.format(self.ALLOWED_FILE_FORMAT))
-         self.raw_format = raw_format.lower()
-
-         # get all .set files in the bids directory
-         temp_dir = (Path().resolve() / 'data')
-         if not os.path.exists(temp_dir):
-             os.mkdir(temp_dir)
-         if not os.path.exists(temp_dir / f'{dataset}_files.npy'):
-             self.files = self.get_files_with_extension_parallel(self.bidsdir, extension=self.RAW_EXTENSION[self.raw_format])
-             np.save(temp_dir / f'{dataset}_files.npy', self.files)
-         else:
-             self.files = np.load(temp_dir / f'{dataset}_files.npy', allow_pickle=True)
+         self.layout = BIDSLayout(data_dir)
+
+         # get all recording files in the bids directory
+         self.files = self.get_recordings(self.layout)
+         assert len(self.files) > 0, ValueError('Unable to construct EEG dataset. No EEG recordings found.')
+         assert self.check_eeg_dataset(), ValueError('Dataset is not an EEG dataset.')
+         # temp_dir = (Path().resolve() / 'data')
+         # if not os.path.exists(temp_dir):
+         #     os.mkdir(temp_dir)
+         # if not os.path.exists(temp_dir / f'{dataset}_files.npy'):
+         #     self.files = self.get_files_with_extension_parallel(self.bidsdir, extension=self.RAW_EXTENSION[self.raw_format])
+         #     np.save(temp_dir / f'{dataset}_files.npy', self.files)
+         # else:
+         #     self.files = np.load(temp_dir / f'{dataset}_files.npy', allow_pickle=True)
+
+     def check_eeg_dataset(self):
+         return self.get_bids_file_attribute('modality', self.files[0]).lower() == 'eeg'
+
+     def get_recordings(self, layout: BIDSLayout):
+         files = []
+         for ext, exts in self.RAW_EXTENSIONS.items():
+             files = layout.get(extension=ext, return_type='filename')
+             if files:
+                 break
+         return files

      def get_relative_bidspath(self, filename):
          bids_parent_dir = self.bidsdir.parent
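Recording discovery now goes through pybids instead of a cached .npy file listing: BIDSLayout.get(extension=..., return_type='filename') is tried for each known raw-data extension, and the first extension that matches any files wins. A standalone sketch of that pattern, with a placeholder dataset path:

    from bids import BIDSLayout

    layout = BIDSLayout("/path/to/bids/dataset")  # placeholder path
    for ext in ('.set', '.edf', '.vhdr', '.bdf'):
        files = layout.get(extension=ext, return_type='filename')
        if files:
            break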
@@ -301,11 +312,6 @@ class BIDSDataset():
              filepath = path / file
              bids_files.append(filepath)

-         # cur_file_basename = file[:file.rfind('_')] # TODO: change to just search for any file with extension
-         # if file.endswith(extension) and cur_file_basename in basename:
-         #     filepath = path / file
-         #     bids_files.append(filepath)
-
          # check if file is in top level directory
          if any(file in os.listdir(path) for file in top_level_files):
              return bids_files
@@ -338,7 +344,7 @@ class BIDSDataset():

      def scan_directory(self, directory, extension):
          result_files = []
-         directory_to_ignore = ['.git']
+         directory_to_ignore = ['.git', '.datalad', 'derivatives', 'code']
          with os.scandir(directory) as entries:
              for entry in entries:
                  if entry.is_file() and entry.name.endswith(extension):
@@ -419,32 +425,22 @@ class BIDSDataset():
                  json_dict.update(json.load(f))
          return json_dict

-     def sfreq(self, data_filepath):
-         json_files = self.get_bids_metadata_files(data_filepath, 'eeg.json')
-         if len(json_files) == 0:
-             raise ValueError('No eeg.json found')
-
-         metadata = self.resolve_bids_json(json_files)
-         if 'SamplingFrequency' not in metadata:
-             raise ValueError('SamplingFrequency not found in metadata')
-         else:
-             return metadata['SamplingFrequency']
-
-     def task(self, data_filepath):
-         return self.get_property_from_filename('task', data_filepath)
-
-     def session(self, data_filepath):
-         return self.get_property_from_filename('session', data_filepath)
-
-     def run(self, data_filepath):
-         return self.get_property_from_filename('run', data_filepath)
-
-     def subject(self, data_filepath):
-         return self.get_property_from_filename('sub', data_filepath)
-
-     def num_channels(self, data_filepath):
-         channels_tsv = pd.read_csv(self.get_bids_metadata_files(data_filepath, 'channels.tsv')[0], sep='\t')
-         return len(channels_tsv)
+     def get_bids_file_attribute(self, attribute, data_filepath):
+         entities = self.layout.parse_file_entities(data_filepath)
+         bidsfile = self.layout.get(**entities)[0]
+         attributes = bidsfile.get_entities(metadata='all')
+         attribute_mapping = {
+             'sfreq': 'SamplingFrequency',
+             'modality': 'datatype',
+             'task': 'task',
+             'session': 'session',
+             'run': 'run',
+             'subject': 'subject',
+             'ntimes': 'RecordingDuration',
+             'nchans': 'EEGChannelCount'
+         }
+         attribute_value = attributes.get(attribute_mapping.get(attribute), None)
+         return attribute_value

      def channel_labels(self, data_filepath):
          channels_tsv = pd.read_csv(self.get_bids_metadata_files(data_filepath, 'channels.tsv')[0], sep='\t')
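The per-attribute helpers (sfreq, task, subject, num_channels, ...) are folded into a single pybids-backed accessor that maps eegdash attribute names to BIDS entities or sidecar metadata keys. A hedged usage sketch; the dataset path and the returned values are illustrative only:

    bids_ds = EEGBIDSDataset(data_dir="/path/to/ds002718", dataset="ds002718")
    first_recording = bids_ds.files[0]
    bids_ds.get_bids_file_attribute("sfreq", first_recording)   # e.g. 256.0
    bids_ds.get_bids_file_attribute("nchans", first_recording)  # e.g. 64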
@@ -462,9 +458,12 @@ class BIDSDataset():
      def subject_participant_tsv(self, data_filepath):
          '''Get participants_tsv info of a subject based on filepath'''
          participants_tsv = pd.read_csv(self.get_bids_metadata_files(data_filepath, 'participants.tsv')[0], sep='\t')
+         # if participants_tsv is not empty
+         if participants_tsv.empty:
+             return {}
          # set 'participant_id' as index
          participants_tsv.set_index('participant_id', inplace=True)
-         subject = f'sub-{self.subject(data_filepath)}'
+         subject = f"sub-{self.get_bids_file_attribute('subject', data_filepath)}"
          return participants_tsv.loc[subject].to_dict()

      def eeg_json(self, data_filepath):
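The subject lookup reduces to indexing participants.tsv by participant_id and returning that row as a dict, which is what feeds the description_fields listed in data_config.py. Roughly, with an illustrative table:

    import pandas as pd

    participants = pd.DataFrame(
        {"participant_id": ["sub-01", "sub-02"], "age": [23, 31], "sex": ["F", "M"]}
    ).set_index("participant_id")
    participants.loc["sub-01"].to_dict()  # {'age': 23, 'sex': 'F'}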
@@ -0,0 +1,25 @@
+ # Features datasets
+ from .datasets import FeaturesDataset, FeaturesConcatDataset
+ from .serialization import load_features_concat_dataset
+
+ # Feature extraction
+ from .extractors import (
+     FeatureExtractor,
+     FitableFeature,
+     UnivariateFeature,
+     BivariateFeature,
+     DirectedBivariateFeature,
+     MultivariateFeature,
+ )
+ from .decorators import (
+     FeaturePredecessor,
+     FeatureKind,
+     univariate_feature,
+     bivariate_feature,
+     directed_bivariate_feature,
+     multivariate_feature,
+ )
+ from .utils import extract_features, fit_feature_extractors
+
+ # Features:
+ from .feature_bank import *
@@ -0,0 +1,453 @@
+ from __future__ import annotations
+ import os
+ import json
+ import shutil
+ import warnings
+ from typing import Dict, no_type_check
+ from collections.abc import Callable, Iterable
+ import numpy as np
+ import pandas as pd
+ from joblib import Parallel, delayed
+ from braindecode.datasets.base import (
+     EEGWindowsDataset,
+     BaseConcatDataset,
+     _create_description,
+ )
+
+
+ class FeaturesDataset(EEGWindowsDataset):
+     """Returns samples from a pandas DataFrame object along with a target.
+
+     Dataset which serves samples from a pandas DataFrame object along with a
+     target. The target is unique for the dataset, and is obtained through the
+     `description` attribute.
+
+     Parameters
+     ----------
+     features : a pandas DataFrame
+         Tabular data.
+     description : dict | pandas.Series | None
+         Holds additional description about the continuous signal / subject.
+     transform : callable | None
+         On-the-fly transform applied to the example before it is returned.
+     """
+
+     def __init__(
+         self,
+         features: pd.DataFrame,
+         metadata: pd.DataFrame | None = None,
+         description: dict | pd.Series | None = None,
+         transform: Callable | None = None,
+         raw_info: Dict | None = None,
+         raw_preproc_kwargs: Dict | None = None,
+         window_kwargs: Dict | None = None,
+         window_preproc_kwargs: Dict | None = None,
+         features_kwargs: Dict | None = None,
+     ):
+         self.features = features
+         self.n_features = features.columns.size
+         self.metadata = metadata
+         self._description = _create_description(description)
+         self.transform = transform
+         self.raw_info = raw_info
+         self.raw_preproc_kwargs = raw_preproc_kwargs
+         self.window_kwargs = window_kwargs
+         self.window_preproc_kwargs = window_preproc_kwargs
+         self.features_kwargs = features_kwargs
+
+         self.crop_inds = metadata.loc[
+             :, ["i_window_in_trial", "i_start_in_trial", "i_stop_in_trial"]
+         ].to_numpy()
+         self.y = metadata.loc[:, "target"].to_list()
+
+     def __getitem__(self, index):
+         crop_inds = self.crop_inds[index].tolist()
+         X = self.features.iloc[index].to_numpy()
+         X = X.copy()
+         X.astype("float32")
+         if self.transform is not None:
+             X = self.transform(X)
+         y = self.y[index]
+         return X, y, crop_inds
+
+     def __len__(self):
+         return len(self.features.index)
+
+
+ def _compute_stats(
+     ds: FeaturesDataset,
+     return_count=False,
+     return_mean=False,
+     return_var=False,
+     ddof=1,
+     numeric_only=False,
+ ):
+     res = []
+     if return_count:
+         res.append(ds.features.count(numeric_only=numeric_only))
+     if return_mean:
+         res.append(ds.features.mean(numeric_only=numeric_only))
+     if return_var:
+         res.append(ds.features.var(ddof=ddof, numeric_only=numeric_only))
+     return tuple(res)
+
+
+ def _pooled_var(counts, means, variances, ddof):
+     count = counts.sum(axis=0)
+     mean = np.sum((counts / count) * means, axis=0)
+     var = np.sum(((counts - ddof) / (count - ddof)) * variances, axis=0)
+     var[:] += np.sum((counts / (count - ddof)) * (means**2), axis=0)
+     var[:] -= (count / (count - ddof)) * (mean**2)
+     var[:] = var.clip(min=0)
+     return count, mean, var
+
+
+ class FeaturesConcatDataset(BaseConcatDataset):
+     """A base class for concatenated datasets.
+
+     Holds either mne.Raw or mne.Epoch in self.datasets and has
+     a pandas DataFrame with additional description.
+
+     Parameters
+     ----------
+     list_of_ds : list
+         list of BaseDataset, BaseConcatDataset or WindowsDataset
+     target_transform : callable | None
+         Optional function to call on targets before returning them.
+
+     """
+
+     def __init__(
+         self,
+         list_of_ds: list[FeaturesDataset] | None = None,
+         target_transform: Callable | None = None,
+     ):
+         # if we get a list of FeaturesConcatDataset, get all the individual datasets
+         if list_of_ds and isinstance(list_of_ds[0], FeaturesConcatDataset):
+             list_of_ds = [d for ds in list_of_ds for d in ds.datasets]
+         super().__init__(list_of_ds)
+
+         self.target_transform = target_transform
+
+     def split(
+         self,
+         by: str | list[int] | list[list[int]] | dict[str, list[int]],
+     ) -> dict[str, FeaturesConcatDataset]:
+         """Split the dataset based on information listed in its description.
+
+         The format could be based on a DataFrame or based on indices.
+
+         Parameters
+         ----------
+         by : str | list | dict
+             If ``by`` is a string, splitting is performed based on the
+             description DataFrame column with this name.
+             If ``by`` is a (list of) list of integers, the position in the first
+             list corresponds to the split id and the integers to the
+             datapoints of that split.
+             If a dict then each key will be used in the returned
+             splits dict and each value should be a list of int.
+
+         Returns
+         -------
+         splits : dict
+             A dictionary with the name of the split (a string) as key and the
+             dataset as value.
+         """
+         if isinstance(by, str):
+             split_ids = {
+                 k: list(v) for k, v in self.description.groupby(by).groups.items()
+             }
+         elif isinstance(by, dict):
+             split_ids = by
+         else:
+             # assume list(int)
+             if not isinstance(by[0], list):
+                 by = [by]
+             # assume list(list(int))
+             split_ids = {split_i: split for split_i, split in enumerate(by)}
+
+         return {
+             str(split_name): FeaturesConcatDataset(
+                 [self.datasets[ds_ind] for ds_ind in ds_inds],
+                 target_transform=self.target_transform,
+             )
+             for split_name, ds_inds in split_ids.items()
+         }
+
+     def get_metadata(self) -> pd.DataFrame:
+         """Concatenate the metadata and description of the wrapped Epochs.
+
+         Returns
+         -------
+         metadata : pd.DataFrame
+             DataFrame containing as many rows as there are windows in the
+             BaseConcatDataset, with the metadata and description information
+             for each window.
+         """
+         if not all([isinstance(ds, FeaturesDataset) for ds in self.datasets]):
+             raise TypeError(
+                 "Metadata dataframe can only be computed when all "
+                 "datasets are FeaturesDataset."
+             )
+
+         all_dfs = list()
+         for ds in self.datasets:
+             df = ds.metadata
+             for k, v in ds.description.items():
+                 df[k] = v
+             all_dfs.append(df)
+
+         return pd.concat(all_dfs)
+
+     def save(self, path: str, overwrite: bool = False, offset: int = 0):
+         """Save datasets to files by creating one subdirectory for each dataset:
+         path/
+             0/
+                 0-feat.parquet
+                 metadata_df.pkl
+                 description.json
+                 raw-info.fif (if raw info was saved)
+                 raw_preproc_kwargs.json (if raws were preprocessed)
+                 window_kwargs.json (if this is a windowed dataset)
+                 window_preproc_kwargs.json (if windows were preprocessed)
+                 features_kwargs.json
+             1/
+                 1-feat.parquet
+                 metadata_df.pkl
+                 description.json
+                 raw-info.fif (if raw info was saved)
+                 raw_preproc_kwargs.json (if raws were preprocessed)
+                 window_kwargs.json (if this is a windowed dataset)
+                 window_preproc_kwargs.json (if windows were preprocessed)
+                 features_kwargs.json
+
+         Parameters
+         ----------
+         path : str
+             Directory in which subdirectories are created to store
+             -feat.parquet and .json files to.
+         overwrite : bool
+             Whether to delete old subdirectories that will be saved to in this
+             call.
+         offset : int
+             If provided, the integer is added to the id of the dataset in the
+             concat. This is useful in the setting of very large datasets, where
+             one dataset has to be processed and saved at a time to account for
+             its original position.
+         """
+         if len(self.datasets) == 0:
+             raise ValueError("Expect at least one dataset")
+         path_contents = os.listdir(path)
+         n_sub_dirs = len([os.path.isdir(e) for e in path_contents])
+         for i_ds, ds in enumerate(self.datasets):
+             # remove subdirectory from list of untouched files / subdirectories
+             if str(i_ds + offset) in path_contents:
+                 path_contents.remove(str(i_ds + offset))
+             # save_dir/i_ds/
+             sub_dir = os.path.join(path, str(i_ds + offset))
+             if os.path.exists(sub_dir):
+                 if overwrite:
+                     shutil.rmtree(sub_dir)
+                 else:
+                     raise FileExistsError(
+                         f"Subdirectory {sub_dir} already exists. Please select"
+                         f" a different directory, set overwrite=True, or "
+                         f"resolve manually."
+                     )
+             # save_dir/{i_ds+offset}/
+             os.makedirs(sub_dir)
+             # save_dir/{i_ds+offset}/{i_ds+offset}-feat.parquet
+             self._save_features(sub_dir, ds, i_ds, offset)
+             # save_dir/{i_ds+offset}/metadata_df.pkl
+             self._save_metadata(sub_dir, ds)
+             # save_dir/{i_ds+offset}/description.json
+             self._save_description(sub_dir, ds.description)
+             # save_dir/{i_ds+offset}/raw-info.fif
+             self._save_raw_info(sub_dir, ds)
+             # save_dir/{i_ds+offset}/raw_preproc_kwargs.json
+             # save_dir/{i_ds+offset}/window_kwargs.json
+             # save_dir/{i_ds+offset}/window_preproc_kwargs.json
+             # save_dir/{i_ds+offset}/features_kwargs.json
+             self._save_kwargs(sub_dir, ds)
+         if overwrite:
+             # the following will be True for all datasets preprocessed and
+             # stored in parallel with braindecode.preprocessing.preprocess
+             if i_ds + 1 + offset < n_sub_dirs:
+                 warnings.warn(
+                     f"The number of saved datasets ({i_ds + 1 + offset}) "
+                     f"does not match the number of existing "
+                     f"subdirectories ({n_sub_dirs}). You may now "
+                     f"encounter a mix of differently preprocessed "
+                     f"datasets!",
+                     UserWarning,
+                 )
+         # if path contains files or directories that were not touched, raise
+         # warning
+         if path_contents:
+             warnings.warn(
+                 f"Chosen directory {path} contains other "
+                 f"subdirectories or files {path_contents}."
+             )
+
+     @staticmethod
+     def _save_features(sub_dir, ds, i_ds, offset):
+         parquet_file_name = f"{i_ds + offset}-feat.parquet"
+         parquet_file_path = os.path.join(sub_dir, parquet_file_name)
+         ds.features.to_parquet(parquet_file_path)
+
+     @staticmethod
+     def _save_raw_info(sub_dir, ds):
+         if hasattr(ds, "raw_info"):
+             fif_file_name = "raw-info.fif"
+             fif_file_path = os.path.join(sub_dir, fif_file_name)
+             ds.raw_info.save(fif_file_path)
+
+     @staticmethod
+     def _save_kwargs(sub_dir, ds):
+         for kwargs_name in [
+             "raw_preproc_kwargs",
+             "window_kwargs",
+             "window_preproc_kwargs",
+             "features_kwargs",
+         ]:
+             if hasattr(ds, kwargs_name):
+                 kwargs_file_name = ".".join([kwargs_name, "json"])
+                 kwargs_file_path = os.path.join(sub_dir, kwargs_file_name)
+                 kwargs = getattr(ds, kwargs_name)
+                 if kwargs is not None:
+                     with open(kwargs_file_path, "w") as f:
+                         json.dump(kwargs, f)
+
+     def to_dataframe(
+         self, include_metadata=False, include_target=False, include_crop_inds=False
+     ):
+         if include_metadata or (include_target and include_crop_inds):
+             dataframes = [
+                 ds.metadata.join(ds.features, how="right", lsuffix="_metadata")
+                 for ds in self.datasets
+             ]
+         elif include_target:
+             dataframes = [
+                 ds.features.join(ds.metadata["target"], how="left", rsuffix="_metadata")
+                 for ds in self.datasets
+             ]
+         elif include_crop_inds:
+             dataframes = [
+                 ds.metadata.drop("target", axis="columns").join(
+                     ds.features, how="right", lsuffix="_metadata"
+                 )
+                 for ds in self.datasets
+             ]
+         else:
+             dataframes = [ds.features for ds in self.datasets]
+         return pd.concat(dataframes, axis=0, ignore_index=True)
+
+     def _numeric_columns(self):
+         return self.datasets[0].features.select_dtypes(include=np.number).columns
+
+     def count(self, numeric_only=False, n_jobs=1):
+         stats = Parallel(n_jobs)(
+             delayed(_compute_stats)(ds, return_count=True, numeric_only=numeric_only)
+             for ds in self.datasets
+         )
+         counts = np.array([s[0] for s in stats])
+         count = counts.sum(axis=0)
+         return pd.Series(count, index=self._numeric_columns())
+
+     def mean(self, numeric_only=False, n_jobs=1):
+         stats = Parallel(n_jobs)(
+             delayed(_compute_stats)(
+                 ds, return_count=True, return_mean=True, numeric_only=numeric_only
+             )
+             for ds in self.datasets
+         )
+         counts, means = np.array([s[0] for s in stats]), np.array([s[1] for s in stats])
+         count = counts.sum(axis=0, keepdims=True)
+         mean = np.sum((counts / count) * means, axis=0)
+         return pd.Series(mean, index=self._numeric_columns())
+
+     def var(self, ddof=1, numeric_only=False, n_jobs=1):
+         stats = Parallel(n_jobs)(
+             delayed(_compute_stats)(
+                 ds,
+                 return_count=True,
+                 return_mean=True,
+                 return_var=True,
+                 ddof=ddof,
+                 numeric_only=numeric_only,
+             )
+             for ds in self.datasets
+         )
+         counts, means, variances = (
+             np.array([s[0] for s in stats]),
+             np.array([s[1] for s in stats]),
+             np.array([s[2] for s in stats]),
+         )
+         _, _, var = _pooled_var(counts, means, variances, ddof)
+         return pd.Series(var, index=self._numeric_columns())
+
+     def std(self, ddof=1, numeric_only=False, n_jobs=1):
+         return np.sqrt(self.var(ddof=ddof, numeric_only=numeric_only, n_jobs=n_jobs))
+
+     def zscore(self, ddof=1, numeric_only=False, eps=0, n_jobs=1):
+         stats = Parallel(n_jobs)(
+             delayed(_compute_stats)(
+                 ds,
+                 return_count=True,
+                 return_mean=True,
+                 return_var=True,
+                 ddof=ddof,
+                 numeric_only=numeric_only,
+             )
+             for ds in self.datasets
+         )
+         counts, means, variances = (
+             np.array([s[0] for s in stats]),
+             np.array([s[1] for s in stats]),
+             np.array([s[2] for s in stats]),
+         )
+         _, mean, var = _pooled_var(counts, means, variances, ddof)
+         std = np.sqrt(var) + eps
+         for ds in self.datasets:
+             ds.features = (ds.features - mean) / std
+
+     @staticmethod
+     def _enforce_inplace_operations(func_name, kwargs):
+         if "inplace" in kwargs and kwargs["inplace"] is False:
+             raise ValueError(
+                 f"{func_name} only works inplace, please change "
+                 + "to inplace=True (default)."
+             )
+         kwargs["inplace"] = True
+
+     def fillna(self, *args, **kwargs):
+         FeaturesConcatDataset._enforce_inplace_operations("fillna", kwargs)
+         for ds in self.datasets:
+             ds.features.fillna(*args, **kwargs)
+
+     def replace(self, *args, **kwargs):
+         FeaturesConcatDataset._enforce_inplace_operations("replace", kwargs)
+         for ds in self.datasets:
+             ds.features.replace(*args, **kwargs)
+
+     def interpolate(self, *args, **kwargs):
+         FeaturesConcatDataset._enforce_inplace_operations("interpolate", kwargs)
+         for ds in self.datasets:
+             ds.features.interpolate(*args, **kwargs)
+
+     def dropna(self, *args, **kwargs):
+         FeaturesConcatDataset._enforce_inplace_operations("dropna", kwargs)
+         for ds in self.datasets:
+             ds.features.dropna(*args, **kwargs)
+
+     def drop(self, *args, **kwargs):
+         FeaturesConcatDataset._enforce_inplace_operations("drop", kwargs)
+         for ds in self.datasets:
+             ds.features.drop(*args, **kwargs)
+
+     def join(self, concat_dataset: FeaturesConcatDataset, **kwargs):
+         assert len(self.datasets) == len(concat_dataset.datasets)
+         for ds1, ds2 in zip(self.datasets, concat_dataset.datasets):
+             assert len(ds1) == len(ds2)
+             ds1.features.join(ds2, **kwargs)
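To show how these pieces fit together, here is a small self-contained sketch that builds two FeaturesDataset objects from synthetic feature tables, concatenates them, and uses the pooled statistics. The import path eegdash.features and all values are assumptions for illustration, not taken from the diff:

    import numpy as np
    import pandas as pd
    from eegdash.features import FeaturesDataset, FeaturesConcatDataset  # assumed import path

    def make_ds(seed):
        # Synthetic per-window feature table plus the metadata columns
        # FeaturesDataset expects (crop indices and a target per window).
        rng = np.random.default_rng(seed)
        feats = pd.DataFrame(rng.normal(size=(10, 3)), columns=["a", "b", "c"])
        meta = pd.DataFrame({
            "i_window_in_trial": range(10),
            "i_start_in_trial": range(10),
            "i_stop_in_trial": range(1, 11),
            "target": [0, 1] * 5,
        })
        return FeaturesDataset(feats, metadata=meta, description={"subject": f"sub-{seed:02d}"})

    concat = FeaturesConcatDataset([make_ds(1), make_ds(2)])
    print(concat.mean(numeric_only=True))    # pooled per-feature means across datasets
    concat.zscore(numeric_only=True)         # standardize features in place
    X, y, crop_inds = concat.datasets[0][0]  # one window: features, target, crop indices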