eegdash 0.4.0.dev173498563__py3-none-any.whl → 0.4.1__py3-none-any.whl


@@ -1,11 +1,41 @@
+from __future__ import annotations
+
 import inspect
 from collections.abc import Callable
 
 from . import extractors, feature_bank
 from .extractors import FeatureExtractor, MultivariateFeature, _get_underlying_func
 
+__all__ = [
+    "get_all_feature_extractors",
+    "get_all_feature_kinds",
+    "get_all_features",
+    "get_feature_kind",
+    "get_feature_predecessors",
+]
+
+
+def get_feature_predecessors(feature_or_extractor: Callable) -> list:
+    """Get the dependency hierarchy for a feature or feature extractor.
+
+    This function recursively traverses the `parent_extractor_type` attribute
+    of a feature or extractor to build a list representing its dependency
+    lineage.
+
+    Parameters
+    ----------
+    feature_or_extractor : callable
+        The feature function or :class:`FeatureExtractor` class to inspect.
+
+    Returns
+    -------
+    list
+        A nested list representing the dependency tree. For a simple linear
+        chain, this will be a flat list from the specific feature up to the
+        base `FeatureExtractor`. For multiple dependencies, it will contain
+        tuples of sub-dependencies.
 
-def get_feature_predecessors(feature_or_extractor: Callable):
+    """
     current = _get_underlying_func(feature_or_extractor)
     if current is FeatureExtractor:
         return [current]
@@ -20,18 +50,59 @@ def get_feature_predecessors(feature_or_extractor: Callable):
     return [current, tuple(predecessors)]
 
 
-def get_feature_kind(feature: Callable):
+def get_feature_kind(feature: Callable) -> MultivariateFeature:
+    """Get the 'kind' of a feature function.
+
+    The feature kind (e.g., univariate, bivariate) is typically attached by a
+    decorator.
+
+    Parameters
+    ----------
+    feature : callable
+        The feature function to inspect.
+
+    Returns
+    -------
+    MultivariateFeature
+        An instance of the feature kind (e.g., `UnivariateFeature()`).
+
+    """
     return _get_underlying_func(feature).feature_kind
 
 
-def get_all_features():
+def get_all_features() -> list[tuple[str, Callable]]:
+    """Get a list of all available feature functions.
+
+    Scans the `eegdash.features.feature_bank` module for functions that have
+    been decorated to have a `feature_kind` attribute.
+
+    Returns
+    -------
+    list[tuple[str, callable]]
+        A list of (name, function) tuples for all discovered features.
+
+    """
+
     def isfeature(x):
         return hasattr(_get_underlying_func(x), "feature_kind")
 
     return inspect.getmembers(feature_bank, isfeature)
 
 
-def get_all_feature_extractors():
+def get_all_feature_extractors() -> list[tuple[str, type[FeatureExtractor]]]:
+    """Get a list of all available `FeatureExtractor` classes.
+
+    Scans the `eegdash.features.feature_bank` module for all classes that
+    subclass :class:`~eegdash.features.extractors.FeatureExtractor`.
+
+    Returns
+    -------
+    list[tuple[str, type[FeatureExtractor]]]
+        A list of (name, class) tuples for all discovered feature extractors,
+        including the base `FeatureExtractor` itself.
+
+    """
+
     def isfeatureextractor(x):
         return inspect.isclass(x) and issubclass(x, FeatureExtractor)
 
@@ -41,7 +112,19 @@ def get_all_feature_extractors():
     ]
 
 
-def get_all_feature_kinds():
+def get_all_feature_kinds() -> list[tuple[str, type[MultivariateFeature]]]:
+    """Get a list of all available feature 'kind' classes.
+
+    Scans the `eegdash.features.extractors` module for all classes that
+    subclass :class:`~eegdash.features.extractors.MultivariateFeature`.
+
+    Returns
+    -------
+    list[tuple[str, type[MultivariateFeature]]]
+        A list of (name, class) tuples for all discovered feature kinds.
+
+    """
+
     def isfeaturekind(x):
         return inspect.isclass(x) and issubclass(x, MultivariateFeature)
 
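Taken together, the new `__all__` and signatures define a small introspection API for the feature bank. A minimal usage sketch, assuming the helpers are re-exported from `eegdash.features` (only the function names and behavior are confirmed by this diff; the import path is an assumption):

```python
# Hypothetical usage of the new introspection helpers; import path assumed.
from eegdash.features import (
    get_all_features,
    get_feature_kind,
    get_feature_predecessors,
)

# Enumerate every decorated feature in the feature bank and report its
# kind and dependency chain.
for name, func in get_all_features():
    kind = get_feature_kind(func)           # e.g. a UnivariateFeature() instance
    chain = get_feature_predecessors(func)  # feature -> ... -> FeatureExtractor
    print(f"{name}: kind={type(kind).__name__}, depth={len(chain)}")
```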
@@ -1,10 +1,13 @@
 """Convenience functions for storing and loading features datasets.
 
-See Also:
-https://github.com/braindecode/braindecode//blob/master/braindecode/datautil/serialization.py#L165-L229
+See Also
+--------
+https://github.com/braindecode/braindecode/blob/master/braindecode/datautil/serialization.py#L165-L229
 
 """
 
+from __future__ import annotations
+
 from pathlib import Path
 
 import pandas as pd
@@ -15,35 +18,45 @@ from braindecode.datautil.serialization import _load_kwargs_json
 
 from .datasets import FeaturesConcatDataset, FeaturesDataset
 
+__all__ = [
+    "load_features_concat_dataset",
+]
+
+
+def load_features_concat_dataset(
+    path: str | Path, ids_to_load: list[int] | None = None, n_jobs: int = 1
+) -> FeaturesConcatDataset:
+    """Load a stored `FeaturesConcatDataset` from a directory.
 
-def load_features_concat_dataset(path, ids_to_load=None, n_jobs=1):
-    """Load a stored features dataset from files.
+    This function reconstructs a :class:`FeaturesConcatDataset` by loading
+    individual :class:`FeaturesDataset` instances from subdirectories within
+    the given path. It uses joblib for parallel loading.
 
     Parameters
     ----------
-    path: str | pathlib.Path
-        Path to the directory of the .fif / -epo.fif and .json files.
-    ids_to_load: list of int | None
-        Ids of specific files to load.
-    n_jobs: int
-        Number of jobs to be used to read files in parallel.
+    path : str or pathlib.Path
+        The path to the directory where the dataset was saved. This directory
+        should contain subdirectories (e.g., "0", "1", "2", ...) for each
+        individual dataset.
+    ids_to_load : list of int, optional
+        A list of specific dataset IDs (subdirectory names) to load. If None,
+        all subdirectories in the path will be loaded.
+    n_jobs : int, default 1
+        The number of jobs to use for parallel loading. -1 means using all
+        processors.
 
     Returns
    -------
-    concat_dataset: eegdash.features.datasets.FeaturesConcatDataset
-        A concatenation of multiple eegdash.features.datasets.FeaturesDataset
-        instances loaded from the given directory.
+    eegdash.features.datasets.FeaturesConcatDataset
+        A concatenated dataset containing the loaded `FeaturesDataset` instances.
 
     """
     # Make sure we always work with a pathlib.Path
     path = Path(path)
 
-    # else we have a dataset saved in the new way with subdirectories in path
-    # for every dataset with description.json and -feat.parquet,
-    # target_name.json, raw_preproc_kwargs.json, window_kwargs.json,
-    # window_preproc_kwargs.json, features_kwargs.json
     if ids_to_load is None:
-        ids_to_load = [p.name for p in path.iterdir()]
+        # Get all subdirectories and sort them numerically
+        ids_to_load = [p.name for p in path.iterdir() if p.is_dir()]
     ids_to_load = sorted(ids_to_load, key=lambda i: int(i))
     ids_to_load = [str(i) for i in ids_to_load]
 
@@ -51,7 +64,26 @@ def load_features_concat_dataset(path, ids_to_load=None, n_jobs=1):
     return FeaturesConcatDataset(datasets)
 
 
-def _load_parallel(path, i):
+def _load_parallel(path: Path, i: str) -> FeaturesDataset:
+    """Load a single `FeaturesDataset` from its subdirectory.
+
+    This is a helper function for `load_features_concat_dataset` that handles
+    the loading of one dataset's files (features, metadata, descriptions, etc.).
+
+    Parameters
+    ----------
+    path : pathlib.Path
+        The root directory of the saved `FeaturesConcatDataset`.
+    i : str
+        The identifier of the dataset to load, corresponding to its
+        subdirectory name.
+
+    Returns
+    -------
+    eegdash.features.datasets.FeaturesDataset
+        The loaded dataset instance.
+
+    """
     sub_dir = path / i
 
     parquet_name_pattern = "{}-feat.parquet"
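The save/load contract implied by this file is one numbered subdirectory per dataset, each holding a `{i}-feat.parquet` file plus JSON sidecars. A minimal loading sketch, assuming `load_features_concat_dataset` is re-exported from `eegdash.features` (the directory name below is hypothetical):

```python
from pathlib import Path

from eegdash.features import load_features_concat_dataset  # import path assumed

# Hypothetical on-disk layout, per the diff above:
#   features_dir/0/0-feat.parquet, features_dir/0/description.json, ...
features_dir = Path("features_dir")

# Load everything in parallel. With the 0.4.1 fix, stray files sitting next
# to the numbered subdirectories no longer break discovery, since only
# directories are listed when ids_to_load is None.
concat_ds = load_features_concat_dataset(features_dir, n_jobs=-1)

# Or load just datasets "0" and "2".
subset = load_features_concat_dataset(features_dir, ids_to_load=[0, 2])
```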
eegdash/features/utils.py CHANGED
@@ -17,12 +17,38 @@ from braindecode.datasets.base import (
 from .datasets import FeaturesConcatDataset, FeaturesDataset
 from .extractors import FeatureExtractor
 
+__all__ = [
+    "extract_features",
+    "fit_feature_extractors",
+]
+
 
 def _extract_features_from_windowsdataset(
     win_ds: EEGWindowsDataset | WindowsDataset,
     feature_extractor: FeatureExtractor,
     batch_size: int = 512,
-):
+) -> FeaturesDataset:
+    """Extract features from a single `WindowsDataset`.
+
+    This is a helper function that iterates through a `WindowsDataset` in
+    batches, applies a `FeatureExtractor`, and returns the results as a
+    `FeaturesDataset`.
+
+    Parameters
+    ----------
+    win_ds : EEGWindowsDataset or WindowsDataset
+        The windowed dataset to extract features from.
+    feature_extractor : FeatureExtractor
+        The feature extractor instance to apply.
+    batch_size : int, default 512
+        The number of windows to process in each batch.
+
+    Returns
+    -------
+    FeaturesDataset
+        A new dataset containing the extracted features and associated metadata.
+
+    """
     metadata = win_ds.metadata
     if not win_ds.targets_from == "metadata":
         metadata = copy.deepcopy(metadata)
@@ -51,18 +77,16 @@
             features_dict[k].extend(v)
     features_df = pd.DataFrame(features_dict)
     if not win_ds.targets_from == "metadata":
-        metadata.set_index("orig_index", drop=False, inplace=True)
         metadata.reset_index(drop=True, inplace=True)
-        metadata.drop("orig_index", axis=1, inplace=True)
+        metadata.drop("orig_index", axis=1, inplace=True, errors="ignore")
 
-    # FUTURE: truly support WindowsDataset objects
     return FeaturesDataset(
         features_df,
         metadata=metadata,
         description=win_ds.description,
         raw_info=win_ds.raw.info,
-        raw_preproc_kwargs=win_ds.raw_preproc_kwargs,
-        window_kwargs=win_ds.window_kwargs,
+        raw_preproc_kwargs=getattr(win_ds, "raw_preproc_kwargs", None),
+        window_kwargs=getattr(win_ds, "window_kwargs", None),
         features_kwargs=feature_extractor.features_kwargs,
     )
 
@@ -73,7 +97,34 @@ def extract_features(
     *,
     batch_size: int = 512,
     n_jobs: int = 1,
-):
+) -> FeaturesConcatDataset:
+    """Extract features from a concatenated dataset of windows.
+
+    This function applies a feature extractor to each `WindowsDataset` within a
+    `BaseConcatDataset` in parallel and returns a `FeaturesConcatDataset`
+    with the results.
+
+    Parameters
+    ----------
+    concat_dataset : BaseConcatDataset
+        A concatenated dataset of `WindowsDataset` or `EEGWindowsDataset`
+        instances.
+    features : FeatureExtractor or dict or list
+        The feature extractor(s) to apply. Can be a `FeatureExtractor`
+        instance, a dictionary of named feature functions, or a list of
+        feature functions.
+    batch_size : int, default 512
+        The size of batches to use for feature extraction.
+    n_jobs : int, default 1
+        The number of parallel jobs to use for extracting features from the
+        datasets.
+
+    Returns
+    -------
+    FeaturesConcatDataset
+        A new concatenated dataset containing the extracted features.
+
+    """
     if isinstance(features, list):
         features = dict(enumerate(features))
     if not isinstance(features, FeatureExtractor):
@@ -97,7 +148,28 @@ def fit_feature_extractors(
     concat_dataset: BaseConcatDataset,
     features: FeatureExtractor | Dict[str, Callable] | List[Callable],
     batch_size: int = 8192,
-):
+) -> FeatureExtractor:
+    """Fit trainable feature extractors on a dataset.
+
+    If the provided feature extractor (or any of its sub-extractors) is
+    trainable (i.e., subclasses `TrainableFeature`), this function iterates
+    through the dataset to fit it.
+
+    Parameters
+    ----------
+    concat_dataset : BaseConcatDataset
+        The dataset to use for fitting the feature extractors.
+    features : FeatureExtractor or dict or list
+        The feature extractor(s) to fit.
+    batch_size : int, default 8192
+        The batch size to use when iterating through the dataset for fitting.
+
+    Returns
+    -------
+    FeatureExtractor
+        The fitted feature extractor.
+
+    """
     if isinstance(features, list):
         features = dict(enumerate(features))
     if not isinstance(features, FeatureExtractor):
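Two behavioral changes in this file are easy to miss among the docstring additions: list inputs are normalized with `dict(enumerate(...))` before being wrapped in a `FeatureExtractor`, and the dataset-attribute reads were hardened with `getattr(..., None)`. A standalone sketch of both patterns (plain Python, no eegdash required; the class below is a hypothetical stand-in):

```python
# 1) How extract_features/fit_feature_extractors normalize a plain list of
#    feature callables into an {index: callable} mapping.
feature_list = [abs, len]  # stand-ins for feature functions
features = dict(enumerate(feature_list))
assert features == {0: abs, 1: len}


# 2) The 0.4.1 hardening: optional dataset attributes are now read with a
#    default instead of assuming every WindowsDataset variant defines them.
class MinimalWindowsDataset:  # hypothetical stand-in, not an eegdash class
    pass


win_ds = MinimalWindowsDataset()
raw_preproc_kwargs = getattr(win_ds, "raw_preproc_kwargs", None)
window_kwargs = getattr(win_ds, "window_kwargs", None)
assert raw_preproc_kwargs is None and window_kwargs is None
```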
@@ -18,27 +18,47 @@ from ..logging import logger
 
 
 class hbn_ec_ec_reannotation(Preprocessor):
-    """Preprocessor to reannotate the raw data for eyes open and eyes closed events.
+    """Preprocessor to reannotate HBN data for eyes-open/eyes-closed events.
 
-    This processor is designed for HBN datasets.
+    This preprocessor is specifically designed for Healthy Brain Network (HBN)
+    datasets. It identifies existing annotations for "instructed_toCloseEyes"
+    and "instructed_toOpenEyes" and creates new, regularly spaced annotations
+    for "eyes_closed" and "eyes_open" segments, respectively.
+
+    This is useful for creating windowed datasets based on these new, more
+    precise event markers.
+
+    Notes
+    -----
+    This class inherits from :class:`braindecode.preprocessing.Preprocessor`
+    and is intended to be used within a braindecode preprocessing pipeline.
 
     """
 
     def __init__(self):
         super().__init__(fn=self.transform, apply_on_array=False)
 
-    def transform(self, raw):
-        """Reannotate the raw data to create new events for eyes open and eyes closed
+    def transform(self, raw: mne.io.Raw) -> mne.io.Raw:
+        """Create new annotations for eyes-open and eyes-closed periods.
 
-        This function modifies the raw MNE object by creating new events based on
-        the existing annotations for "instructed_toCloseEyes" and "instructed_toOpenEyes".
-        It generates new events every 2 seconds within specified time ranges after
-        the original events, and replaces the existing annotations with these new events.
+        This function finds the original "instructed_to..." annotations and
+        generates new annotations every 2 seconds within specific time ranges
+        relative to the original markers:
+        - "eyes_closed": 15s to 29s after "instructed_toCloseEyes"
+        - "eyes_open": 5s to 19s after "instructed_toOpenEyes"
+
+        The original annotations in the `mne.io.Raw` object are replaced by
+        this new set of annotations.
 
         Parameters
         ----------
         raw : mne.io.Raw
-            The raw MNE object containing EEG data and annotations.
+            The raw MNE object containing the HBN data and original annotations.
+
+        Returns
+        -------
+        mne.io.Raw
+            The raw MNE object with the modified annotations.
 
         """
         events, event_id = mne.events_from_annotations(raw)
@@ -48,15 +68,27 @@ class hbn_ec_ec_reannotation(Preprocessor):
         # Create new events array for 2-second segments
         new_events = []
         sfreq = raw.info["sfreq"]
-        for event in events[events[:, 2] == event_id["instructed_toCloseEyes"]]:
-            # For each original event, create events every 2 seconds from 15s to 29s after
-            start_times = event[0] + np.arange(15, 29, 2) * sfreq
-            new_events.extend([[int(t), 0, 1] for t in start_times])
 
-        for event in events[events[:, 2] == event_id["instructed_toOpenEyes"]]:
-            # For each original event, create events every 2 seconds from 5s to 19s after
-            start_times = event[0] + np.arange(5, 19, 2) * sfreq
-            new_events.extend([[int(t), 0, 2] for t in start_times])
+        close_event_id = event_id.get("instructed_toCloseEyes")
+        if close_event_id:
+            for event in events[events[:, 2] == close_event_id]:
+                # For each original event, create events every 2s from 15s to 29s after
+                start_times = event[0] + np.arange(15, 29, 2) * sfreq
+                new_events.extend([[int(t), 0, 1] for t in start_times])
+
+        open_event_id = event_id.get("instructed_toOpenEyes")
+        if open_event_id:
+            for event in events[events[:, 2] == open_event_id]:
+                # For each original event, create events every 2s from 5s to 19s after
+                start_times = event[0] + np.arange(5, 19, 2) * sfreq
+                new_events.extend([[int(t), 0, 2] for t in start_times])
+
+        if not new_events:
+            logger.warning(
+                "Could not find 'instructed_toCloseEyes' or 'instructed_toOpenEyes' "
+                "annotations. No new events created."
+            )
+            return raw
 
         # replace events in raw
         new_events = np.array(new_events)
@@ -65,6 +97,7 @@ class hbn_ec_ec_reannotation(Preprocessor):
             events=new_events,
             event_desc={1: "eyes_closed", 2: "eyes_open"},
             sfreq=raw.info["sfreq"],
+            orig_time=raw.info.get("meas_date"),
         )
 
         raw.set_annotations(annot_from_events)
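The event-spacing arithmetic in `transform` deserves a closer look: `np.arange` has an exclusive upper bound, so "eyes_closed" markers land at 15, 17, ..., 27 s after the instruction (seven events per block), and "eyes_open" markers at 5, 7, ..., 17 s. A standalone check of that arithmetic (the sampling rate and marker sample below are made up):

```python
import numpy as np

sfreq = 500.0   # hypothetical sampling rate
onset = 10_000  # hypothetical sample index of an "instructed_toCloseEyes" marker

# Same computation as in transform(): one event every 2 s, at offsets
# 15..27 s inclusive (np.arange's stop of 29 is exclusive).
start_times = onset + np.arange(15, 29, 2) * sfreq
new_events = [[int(t), 0, 1] for t in start_times]

assert len(new_events) == 7
assert new_events[0][0] == onset + int(15 * sfreq)
assert new_events[-1][0] == onset + int(27 * sfreq)
```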