eegdash 0.4.0.dev173498563__py3-none-any.whl → 0.4.1.dev185__py3-none-any.whl

This diff shows the changes between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the differences between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of eegdash might be problematic. Click here for more details.

@@ -1,15 +1,48 @@
1
+ from __future__ import annotations
2
+
1
3
  import inspect
2
4
  from collections.abc import Callable
3
5
 
4
6
  from . import extractors, feature_bank
5
- from .extractors import FeatureExtractor, MultivariateFeature, _get_underlying_func
7
+ from .extractors import _get_underlying_func
8
+
9
+ __all__ = [
10
+ "get_all_feature_extractors",
11
+ "get_all_feature_kinds",
12
+ "get_all_features",
13
+ "get_feature_kind",
14
+ "get_feature_predecessors",
15
+ ]
16
+
17
+
18
+ def get_feature_predecessors(feature_or_extractor: Callable) -> list:
19
+ """Get the dependency hierarchy for a feature or feature extractor.
20
+
21
+ This function recursively traverses the `parent_extractor_type` attribute
22
+ of a feature or extractor to build a list representing its dependency
23
+ lineage.
24
+
25
+ Parameters
26
+ ----------
27
+ feature_or_extractor : callable
28
+ The feature function or :class:`~eegdash.features.extractors.FeatureExtractor`
29
+ class to inspect.
6
30
 
31
+ Returns
32
+ -------
33
+ list
34
+ A nested list representing the dependency tree. For a simple linear
35
+ chain, this will be a flat list from the specific feature up to the
36
+ base :class:`~eegdash.features.extractors.FeatureExtractor`. For
37
+ multiple dependencies, it will contain tuples of sub-dependencies.
7
38
 
8
- def get_feature_predecessors(feature_or_extractor: Callable):
39
+ """
9
40
  current = _get_underlying_func(feature_or_extractor)
10
- if current is FeatureExtractor:
41
+ if current is extractors.FeatureExtractor:
11
42
  return [current]
12
- predecessor = getattr(current, "parent_extractor_type", [FeatureExtractor])
43
+ predecessor = getattr(
44
+ current, "parent_extractor_type", [extractors.FeatureExtractor]
45
+ )
13
46
  if len(predecessor) == 1:
14
47
  return [current, *get_feature_predecessors(predecessor[0])]
15
48
  else:
@@ -20,29 +53,82 @@ def get_feature_predecessors(feature_or_extractor: Callable):
20
53
  return [current, tuple(predecessors)]
21
54
 
22
55
 
23
- def get_feature_kind(feature: Callable):
56
+ def get_feature_kind(feature: Callable) -> extractors.MultivariateFeature:
57
+ """Get the 'kind' of a feature function.
58
+
59
+ The feature kind (e.g., univariate, bivariate) is typically attached by a
60
+ decorator.
61
+
62
+ Parameters
63
+ ----------
64
+ feature : callable
65
+ The feature function to inspect.
66
+
67
+ Returns
68
+ -------
69
+ :class:`~eegdash.features.extractors.MultivariateFeature`
70
+ An instance of the feature kind (e.g., ``UnivariateFeature()``).
71
+
72
+ """
24
73
  return _get_underlying_func(feature).feature_kind
25
74
 
26
75
 
27
- def get_all_features():
76
+ def get_all_features() -> list[tuple[str, Callable]]:
77
+ """Get a list of all available feature functions.
78
+
79
+ Scans the `eegdash.features.feature_bank` module for functions that have
80
+ been decorated to have a `feature_kind` attribute.
81
+
82
+ Returns
83
+ -------
84
+ list[tuple[str, callable]]
85
+ A list of (name, function) tuples for all discovered features.
86
+
87
+ """
88
+
28
89
  def isfeature(x):
29
90
  return hasattr(_get_underlying_func(x), "feature_kind")
30
91
 
31
92
  return inspect.getmembers(feature_bank, isfeature)
32
93
 
33
94
 
34
- def get_all_feature_extractors():
95
+ def get_all_feature_extractors() -> list[tuple[str, type[extractors.FeatureExtractor]]]:
96
+ """Get a list of all available :class:`~eegdash.features.extractors.FeatureExtractor` classes.
97
+
98
+ Scans the `eegdash.features.feature_bank` module for all classes that
99
+ subclass :class:`~eegdash.features.extractors.FeatureExtractor`.
100
+
101
+ Returns
102
+ -------
103
+ list[tuple[str, type[eegdash.features.extractors.FeatureExtractor]]]
104
+ A list of (name, class) tuples for all discovered feature extractors,
105
+ including the base :class:`~eegdash.features.extractors.FeatureExtractor` itself.
106
+
107
+ """
108
+
35
109
  def isfeatureextractor(x):
36
- return inspect.isclass(x) and issubclass(x, FeatureExtractor)
110
+ return inspect.isclass(x) and issubclass(x, extractors.FeatureExtractor)
37
111
 
38
112
  return [
39
- ("FeatureExtractor", FeatureExtractor),
113
+ ("FeatureExtractor", extractors.FeatureExtractor),
40
114
  *inspect.getmembers(feature_bank, isfeatureextractor),
41
115
  ]
42
116
 
43
117
 
44
- def get_all_feature_kinds():
118
+ def get_all_feature_kinds() -> list[tuple[str, type[extractors.MultivariateFeature]]]:
119
+ """Get a list of all available feature 'kind' classes.
120
+
121
+ Scans the `eegdash.features.extractors` module for all classes that
122
+ subclass :class:`~eegdash.features.extractors.MultivariateFeature`.
123
+
124
+ Returns
125
+ -------
126
+ list[tuple[str, type[eegdash.features.extractors.MultivariateFeature]]]
127
+ A list of (name, class) tuples for all discovered feature kinds.
128
+
129
+ """
130
+
45
131
  def isfeaturekind(x):
46
- return inspect.isclass(x) and issubclass(x, MultivariateFeature)
132
+ return inspect.isclass(x) and issubclass(x, extractors.MultivariateFeature)
47
133
 
48
134
  return inspect.getmembers(extractors, isfeaturekind)
@@ -1,10 +1,13 @@
1
1
  """Convenience functions for storing and loading features datasets.
2
2
 
3
- See Also:
4
- https://github.com/braindecode/braindecode//blob/master/braindecode/datautil/serialization.py#L165-L229
3
+ See Also
4
+ --------
5
+ https://github.com/braindecode/braindecode/blob/master/braindecode/datautil/serialization.py#L165-L229
5
6
 
6
7
  """
7
8
 
9
+ from __future__ import annotations
10
+
8
11
  from pathlib import Path
9
12
 
10
13
  import pandas as pd
@@ -15,35 +18,48 @@ from braindecode.datautil.serialization import _load_kwargs_json
15
18
 
16
19
  from .datasets import FeaturesConcatDataset, FeaturesDataset
17
20
 
21
+ __all__ = [
22
+ "load_features_concat_dataset",
23
+ ]
24
+
25
+
26
+ def load_features_concat_dataset(
27
+ path: str | Path, ids_to_load: list[int] | None = None, n_jobs: int = 1
28
+ ) -> FeaturesConcatDataset:
29
+ """Load a stored :class:`~eegdash.features.datasets.FeaturesConcatDataset` from a directory.
18
30
 
19
- def load_features_concat_dataset(path, ids_to_load=None, n_jobs=1):
20
- """Load a stored features dataset from files.
31
+ This function reconstructs a
32
+ :class:`~eegdash.features.datasets.FeaturesConcatDataset` by loading
33
+ individual :class:`~eegdash.features.datasets.FeaturesDataset` instances
34
+ from subdirectories within the given path. It uses joblib for parallel
35
+ loading.
21
36
 
22
37
  Parameters
23
38
  ----------
24
- path: str | pathlib.Path
25
- Path to the directory of the .fif / -epo.fif and .json files.
26
- ids_to_load: list of int | None
27
- Ids of specific files to load.
28
- n_jobs: int
29
- Number of jobs to be used to read files in parallel.
39
+ path : str or pathlib.Path
40
+ The path to the directory where the dataset was saved. This directory
41
+ should contain subdirectories (e.g., "0", "1", "2", ...) for each
42
+ individual dataset.
43
+ ids_to_load : list of int, optional
44
+ A list of specific dataset IDs (subdirectory names) to load. If None,
45
+ all subdirectories in the path will be loaded.
46
+ n_jobs : int, default 1
47
+ The number of jobs to use for parallel loading. -1 means using all
48
+ processors.
30
49
 
31
50
  Returns
32
51
  -------
33
- concat_dataset: eegdash.features.datasets.FeaturesConcatDataset
34
- A concatenation of multiple eegdash.features.datasets.FeaturesDataset
35
- instances loaded from the given directory.
52
+ eegdash.features.datasets.FeaturesConcatDataset
53
+ A concatenated dataset containing the loaded
54
+ :class:`~eegdash.features.datasets.FeaturesDataset` instances.
36
55
 
37
56
  """
38
57
  # Make sure we always work with a pathlib.Path
39
58
  path = Path(path)
40
59
 
41
- # else we have a dataset saved in the new way with subdirectories in path
42
- # for every dataset with description.json and -feat.parquet,
43
- # target_name.json, raw_preproc_kwargs.json, window_kwargs.json,
44
- # window_preproc_kwargs.json, features_kwargs.json
45
60
  if ids_to_load is None:
46
- ids_to_load = [p.name for p in path.iterdir()]
61
+ # Get all subdirectories and sort them numerically
62
+ ids_to_load = [p.name for p in path.iterdir() if p.is_dir()]
47
63
  ids_to_load = sorted(ids_to_load, key=lambda i: int(i))
48
64
  ids_to_load = [str(i) for i in ids_to_load]
49
65
 
@@ -51,7 +67,28 @@ def load_features_concat_dataset(path, ids_to_load=None, n_jobs=1):
51
67
  return FeaturesConcatDataset(datasets)
52
68
 
53
69
 
54
- def _load_parallel(path, i):
70
+ def _load_parallel(path: Path, i: str) -> FeaturesDataset:
71
+ """Load a single :class:`~eegdash.features.datasets.FeaturesDataset` from its subdirectory.
72
+
73
+ This is a helper function for
74
+ :func:`~eegdash.features.serialization.load_features_concat_dataset` that
75
+ handles the loading of one dataset's files (features, metadata, descriptions, etc.).
76
+
77
+ Parameters
78
+ ----------
79
+ path : pathlib.Path
80
+ The root directory of the saved
81
+ :class:`~eegdash.features.datasets.FeaturesConcatDataset`.
82
+ i : str
83
+ The identifier of the dataset to load, corresponding to its
84
+ subdirectory name.
85
+
86
+ Returns
87
+ -------
88
+ eegdash.features.datasets.FeaturesDataset
89
+ The loaded dataset instance.
90
+
91
+ """
55
92
  sub_dir = path / i
56
93
 
57
94
  parquet_name_pattern = "{}-feat.parquet"
eegdash/features/utils.py CHANGED
@@ -14,15 +14,41 @@ from braindecode.datasets.base import (
14
14
  WindowsDataset,
15
15
  )
16
16
 
17
+ from . import extractors
17
18
  from .datasets import FeaturesConcatDataset, FeaturesDataset
18
- from .extractors import FeatureExtractor
19
+
20
+ __all__ = [
21
+ "extract_features",
22
+ "fit_feature_extractors",
23
+ ]
19
24
 
20
25
 
21
26
  def _extract_features_from_windowsdataset(
22
27
  win_ds: EEGWindowsDataset | WindowsDataset,
23
- feature_extractor: FeatureExtractor,
28
+ feature_extractor: extractors.FeatureExtractor,
24
29
  batch_size: int = 512,
25
- ):
30
+ ) -> FeaturesDataset:
31
+ """Extract features from a single `WindowsDataset`.
32
+
33
+ This is a helper function that iterates through a `WindowsDataset` in
34
+ batches, applies a `FeatureExtractor`, and returns the results as a
35
+ `FeaturesDataset`.
36
+
37
+ Parameters
38
+ ----------
39
+ win_ds : EEGWindowsDataset or WindowsDataset
40
+ The windowed dataset to extract features from.
41
+ feature_extractor : ~eegdash.features.extractors.FeatureExtractor
42
+ The feature extractor instance to apply.
43
+ batch_size : int, default 512
44
+ The number of windows to process in each batch.
45
+
46
+ Returns
47
+ -------
48
+ ~eegdash.features.datasets.FeaturesDataset
49
+ A new dataset containing the extracted features and associated metadata.
50
+
51
+ """
26
52
  metadata = win_ds.metadata
27
53
  if not win_ds.targets_from == "metadata":
28
54
  metadata = copy.deepcopy(metadata)
@@ -51,33 +77,59 @@ def _extract_features_from_windowsdataset(
51
77
  features_dict[k].extend(v)
52
78
  features_df = pd.DataFrame(features_dict)
53
79
  if not win_ds.targets_from == "metadata":
54
- metadata.set_index("orig_index", drop=False, inplace=True)
55
80
  metadata.reset_index(drop=True, inplace=True)
56
- metadata.drop("orig_index", axis=1, inplace=True)
81
+ metadata.drop("orig_index", axis=1, inplace=True, errors="ignore")
57
82
 
58
- # FUTURE: truly support WindowsDataset objects
59
83
  return FeaturesDataset(
60
84
  features_df,
61
85
  metadata=metadata,
62
86
  description=win_ds.description,
63
87
  raw_info=win_ds.raw.info,
64
- raw_preproc_kwargs=win_ds.raw_preproc_kwargs,
65
- window_kwargs=win_ds.window_kwargs,
88
+ raw_preproc_kwargs=getattr(win_ds, "raw_preproc_kwargs", None),
89
+ window_kwargs=getattr(win_ds, "window_kwargs", None),
66
90
  features_kwargs=feature_extractor.features_kwargs,
67
91
  )
68
92
 
69
93
 
70
94
  def extract_features(
71
95
  concat_dataset: BaseConcatDataset,
72
- features: FeatureExtractor | Dict[str, Callable] | List[Callable],
96
+ features: extractors.FeatureExtractor | Dict[str, Callable] | List[Callable],
73
97
  *,
74
98
  batch_size: int = 512,
75
99
  n_jobs: int = 1,
76
- ):
100
+ ) -> FeaturesConcatDataset:
101
+ """Extract features from a concatenated dataset of windows.
102
+
103
+ This function applies a feature extractor to each `WindowsDataset` within a
104
+ `BaseConcatDataset` in parallel and returns a `FeaturesConcatDataset`
105
+ with the results.
106
+
107
+ Parameters
108
+ ----------
109
+ concat_dataset : BaseConcatDataset
110
+ A concatenated dataset of `WindowsDataset` or `EEGWindowsDataset`
111
+ instances.
112
+ features : ~eegdash.features.extractors.FeatureExtractor or dict or list
113
+ The feature extractor(s) to apply. Can be a
114
+ :class:`~eegdash.features.extractors.FeatureExtractor`
115
+ instance, a dictionary of named feature functions, or a list of
116
+ feature functions.
117
+ batch_size : int, default 512
118
+ The size of batches to use for feature extraction.
119
+ n_jobs : int, default 1
120
+ The number of parallel jobs to use for extracting features from the
121
+ datasets.
122
+
123
+ Returns
124
+ -------
125
+ ~eegdash.features.datasets.FeaturesConcatDataset
126
+ A new concatenated dataset containing the extracted features.
127
+
128
+ """
77
129
  if isinstance(features, list):
78
130
  features = dict(enumerate(features))
79
- if not isinstance(features, FeatureExtractor):
80
- features = FeatureExtractor(features)
131
+ if not isinstance(features, extractors.FeatureExtractor):
132
+ features = extractors.FeatureExtractor(features)
81
133
  feature_ds_list = list(
82
134
  tqdm(
83
135
  Parallel(n_jobs=n_jobs, return_as="generator")(
@@ -95,13 +147,35 @@ def extract_features(
95
147
 
96
148
  def fit_feature_extractors(
97
149
  concat_dataset: BaseConcatDataset,
98
- features: FeatureExtractor | Dict[str, Callable] | List[Callable],
150
+ features: extractors.FeatureExtractor | Dict[str, Callable] | List[Callable],
99
151
  batch_size: int = 8192,
100
- ):
152
+ ) -> extractors.FeatureExtractor:
153
+ """Fit trainable feature extractors on a dataset.
154
+
155
+ If the provided feature extractor (or any of its sub-extractors) is
156
+ trainable (i.e., subclasses
157
+ :class:`~eegdash.features.extractors.TrainableFeature`), this function
158
+ iterates through the dataset to fit it.
159
+
160
+ Parameters
161
+ ----------
162
+ concat_dataset : BaseConcatDataset
163
+ The dataset to use for fitting the feature extractors.
164
+ features : ~eegdash.features.extractors.FeatureExtractor or dict or list
165
+ The feature extractor(s) to fit.
166
+ batch_size : int, default 8192
167
+ The batch size to use when iterating through the dataset for fitting.
168
+
169
+ Returns
170
+ -------
171
+ ~eegdash.features.extractors.FeatureExtractor
172
+ The fitted feature extractor.
173
+
174
+ """
101
175
  if isinstance(features, list):
102
176
  features = dict(enumerate(features))
103
- if not isinstance(features, FeatureExtractor):
104
- features = FeatureExtractor(features)
177
+ if not isinstance(features, extractors.FeatureExtractor):
178
+ features = extractors.FeatureExtractor(features)
105
179
  if not features._is_trainable:
106
180
  return features
107
181
  features.clear()
@@ -18,27 +18,47 @@ from ..logging import logger
18
18
 
19
19
 
20
20
  class hbn_ec_ec_reannotation(Preprocessor):
21
- """Preprocessor to reannotate the raw data for eyes open and eyes closed events.
21
+ """Preprocessor to reannotate HBN data for eyes-open/eyes-closed events.
22
22
 
23
- This processor is designed for HBN datasets.
23
+ This preprocessor is specifically designed for Healthy Brain Network (HBN)
24
+ datasets. It identifies existing annotations for "instructed_toCloseEyes"
25
+ and "instructed_toOpenEyes" and creates new, regularly spaced annotations
26
+ for "eyes_closed" and "eyes_open" segments, respectively.
27
+
28
+ This is useful for creating windowed datasets based on these new, more
29
+ precise event markers.
30
+
31
+ Notes
32
+ -----
33
+ This class inherits from :class:`braindecode.preprocessing.Preprocessor`
34
+ and is intended to be used within a braindecode preprocessing pipeline.
24
35
 
25
36
  """
26
37
 
27
38
  def __init__(self):
28
39
  super().__init__(fn=self.transform, apply_on_array=False)
29
40
 
30
- def transform(self, raw):
31
- """Reannotate the raw data to create new events for eyes open and eyes closed
41
+ def transform(self, raw: mne.io.Raw) -> mne.io.Raw:
42
+ """Create new annotations for eyes-open and eyes-closed periods.
32
43
 
33
- This function modifies the raw MNE object by creating new events based on
34
- the existing annotations for "instructed_toCloseEyes" and "instructed_toOpenEyes".
35
- It generates new events every 2 seconds within specified time ranges after
36
- the original events, and replaces the existing annotations with these new events.
44
+ This function finds the original "instructed_to..." annotations and
45
+ generates new annotations every 2 seconds within specific time ranges
46
+ relative to the original markers:
47
+ - "eyes_closed": 15s to 29s after "instructed_toCloseEyes"
48
+ - "eyes_open": 5s to 19s after "instructed_toOpenEyes"
49
+
50
+ The original annotations in the `mne.io.Raw` object are replaced by
51
+ this new set of annotations.
37
52
 
38
53
  Parameters
39
54
  ----------
40
55
  raw : mne.io.Raw
41
- The raw MNE object containing EEG data and annotations.
56
+ The raw MNE object containing the HBN data and original annotations.
57
+
58
+ Returns
59
+ -------
60
+ mne.io.Raw
61
+ The raw MNE object with the modified annotations.
42
62
 
43
63
  """
44
64
  events, event_id = mne.events_from_annotations(raw)
@@ -48,15 +68,27 @@ class hbn_ec_ec_reannotation(Preprocessor):
48
68
  # Create new events array for 2-second segments
49
69
  new_events = []
50
70
  sfreq = raw.info["sfreq"]
51
- for event in events[events[:, 2] == event_id["instructed_toCloseEyes"]]:
52
- # For each original event, create events every 2 seconds from 15s to 29s after
53
- start_times = event[0] + np.arange(15, 29, 2) * sfreq
54
- new_events.extend([[int(t), 0, 1] for t in start_times])
55
71
 
56
- for event in events[events[:, 2] == event_id["instructed_toOpenEyes"]]:
57
- # For each original event, create events every 2 seconds from 5s to 19s after
58
- start_times = event[0] + np.arange(5, 19, 2) * sfreq
59
- new_events.extend([[int(t), 0, 2] for t in start_times])
72
+ close_event_id = event_id.get("instructed_toCloseEyes")
73
+ if close_event_id:
74
+ for event in events[events[:, 2] == close_event_id]:
75
+ # For each original event, create events every 2s from 15s to 29s after
76
+ start_times = event[0] + np.arange(15, 29, 2) * sfreq
77
+ new_events.extend([[int(t), 0, 1] for t in start_times])
78
+
79
+ open_event_id = event_id.get("instructed_toOpenEyes")
80
+ if open_event_id:
81
+ for event in events[events[:, 2] == open_event_id]:
82
+ # For each original event, create events every 2s from 5s to 19s after
83
+ start_times = event[0] + np.arange(5, 19, 2) * sfreq
84
+ new_events.extend([[int(t), 0, 2] for t in start_times])
85
+
86
+ if not new_events:
87
+ logger.warning(
88
+ "Could not find 'instructed_toCloseEyes' or 'instructed_toOpenEyes' "
89
+ "annotations. No new events created."
90
+ )
91
+ return raw
60
92
 
61
93
  # replace events in raw
62
94
  new_events = np.array(new_events)
@@ -65,6 +97,7 @@ class hbn_ec_ec_reannotation(Preprocessor):
65
97
  events=new_events,
66
98
  event_desc={1: "eyes_closed", 2: "eyes_open"},
67
99
  sfreq=raw.info["sfreq"],
100
+ orig_time=raw.info.get("meas_date"),
68
101
  )
69
102
 
70
103
  raw.set_annotations(annot_from_events)