eegdash 0.3.9.dev182388821__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


@@ -1,8 +1,13 @@
-"""Convenience functions for storing and loading of features datasets.
+"""Convenience functions for storing and loading features datasets.
+
+See Also
+--------
+https://github.com/braindecode/braindecode//blob/master/braindecode/datautil/serialization.py#L165-L229
 
-see also: https://github.com/braindecode/braindecode//blob/master/braindecode/datautil/serialization.py#L165-L229
 """
 
+from __future__ import annotations
+
 from pathlib import Path
 
 import pandas as pd
@@ -14,32 +19,40 @@ from braindecode.datautil.serialization import _load_kwargs_json
 from .datasets import FeaturesConcatDataset, FeaturesDataset
 
 
-def load_features_concat_dataset(path, ids_to_load=None, n_jobs=1):
-    """Load a stored FeaturesConcatDataset of FeaturesDatasets from files.
+def load_features_concat_dataset(
+    path: str | Path, ids_to_load: list[int] | None = None, n_jobs: int = 1
+) -> FeaturesConcatDataset:
+    """Load a stored `FeaturesConcatDataset` from a directory.
+
+    This function reconstructs a :class:`FeaturesConcatDataset` by loading
+    individual :class:`FeaturesDataset` instances from subdirectories within
+    the given path. It uses joblib for parallel loading.
 
     Parameters
     ----------
-    path: str | pathlib.Path
-        Path to the directory of the .fif / -epo.fif and .json files.
-    ids_to_load: list of int | None
-        Ids of specific files to load.
-    n_jobs: int
-        Number of jobs to be used to read files in parallel.
+    path : str or pathlib.Path
+        The path to the directory where the dataset was saved. This directory
+        should contain subdirectories (e.g., "0", "1", "2", ...) for each
+        individual dataset.
+    ids_to_load : list of int, optional
+        A list of specific dataset IDs (subdirectory names) to load. If None,
+        all subdirectories in the path will be loaded.
+    n_jobs : int, default 1
+        The number of jobs to use for parallel loading. -1 means using all
+        processors.
 
     Returns
     -------
-    concat_dataset: FeaturesConcatDataset of FeaturesDatasets
+    eegdash.features.datasets.FeaturesConcatDataset
+        A concatenated dataset containing the loaded `FeaturesDataset` instances.
 
     """
     # Make sure we always work with a pathlib.Path
     path = Path(path)
 
-    # else we have a dataset saved in the new way with subdirectories in path
-    # for every dataset with description.json and -feat.parquet,
-    # target_name.json, raw_preproc_kwargs.json, window_kwargs.json,
-    # window_preproc_kwargs.json, features_kwargs.json
     if ids_to_load is None:
-        ids_to_load = [p.name for p in path.iterdir()]
+        # Get all subdirectories and sort them numerically
+        ids_to_load = [p.name for p in path.iterdir() if p.is_dir()]
     ids_to_load = sorted(ids_to_load, key=lambda i: int(i))
     ids_to_load = [str(i) for i in ids_to_load]
 
@@ -47,7 +60,26 @@ def load_features_concat_dataset(path, ids_to_load=None, n_jobs=1):
     return FeaturesConcatDataset(datasets)
 
 
-def _load_parallel(path, i):
+def _load_parallel(path: Path, i: str) -> FeaturesDataset:
+    """Load a single `FeaturesDataset` from its subdirectory.
+
+    This is a helper function for `load_features_concat_dataset` that handles
+    the loading of one dataset's files (features, metadata, descriptions, etc.).
+
+    Parameters
+    ----------
+    path : pathlib.Path
+        The root directory of the saved `FeaturesConcatDataset`.
+    i : str
+        The identifier of the dataset to load, corresponding to its
+        subdirectory name.
+
+    Returns
+    -------
+    eegdash.features.datasets.FeaturesDataset
+        The loaded dataset instance.
+
+    """
     sub_dir = path / i
 
     parquet_name_pattern = "{}-feat.parquet"
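
For reference, a minimal usage sketch of the updated loader, assuming the helper is re-exported from eegdash.features (the module path for this file is not shown in the diff) and that "saved_features" is a directory previously written by the library:

    from eegdash.features import load_features_concat_dataset  # import path assumed

    # Load every numbered subdirectory ("0", "1", ...) of the save directory,
    # reading each dataset's parquet/json files with four parallel jobs.
    concat_ds = load_features_concat_dataset("saved_features", n_jobs=4)

    # Load only a subset of datasets by their subdirectory names.
    subset = load_features_concat_dataset("saved_features", ids_to_load=[0, 2])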
eegdash/features/utils.py CHANGED
@@ -22,7 +22,28 @@ def _extract_features_from_windowsdataset(
     win_ds: EEGWindowsDataset | WindowsDataset,
     feature_extractor: FeatureExtractor,
     batch_size: int = 512,
-):
+) -> FeaturesDataset:
+    """Extract features from a single `WindowsDataset`.
+
+    This is a helper function that iterates through a `WindowsDataset` in
+    batches, applies a `FeatureExtractor`, and returns the results as a
+    `FeaturesDataset`.
+
+    Parameters
+    ----------
+    win_ds : EEGWindowsDataset or WindowsDataset
+        The windowed dataset to extract features from.
+    feature_extractor : FeatureExtractor
+        The feature extractor instance to apply.
+    batch_size : int, default 512
+        The number of windows to process in each batch.
+
+    Returns
+    -------
+    FeaturesDataset
+        A new dataset containing the extracted features and associated metadata.
+
+    """
     metadata = win_ds.metadata
     if not win_ds.targets_from == "metadata":
         metadata = copy.deepcopy(metadata)
@@ -51,18 +72,16 @@ def _extract_features_from_windowsdataset(
             features_dict[k].extend(v)
     features_df = pd.DataFrame(features_dict)
     if not win_ds.targets_from == "metadata":
-        metadata.set_index("orig_index", drop=False, inplace=True)
         metadata.reset_index(drop=True, inplace=True)
-        metadata.drop("orig_index", axis=1, inplace=True)
+        metadata.drop("orig_index", axis=1, inplace=True, errors="ignore")
 
-    # FUTURE: truly support WindowsDataset objects
     return FeaturesDataset(
         features_df,
         metadata=metadata,
         description=win_ds.description,
         raw_info=win_ds.raw.info,
-        raw_preproc_kwargs=win_ds.raw_preproc_kwargs,
-        window_kwargs=win_ds.window_kwargs,
+        raw_preproc_kwargs=getattr(win_ds, "raw_preproc_kwargs", None),
+        window_kwargs=getattr(win_ds, "window_kwargs", None),
         features_kwargs=feature_extractor.features_kwargs,
     )
 
@@ -73,7 +92,34 @@ def extract_features(
     *,
     batch_size: int = 512,
     n_jobs: int = 1,
-):
+) -> FeaturesConcatDataset:
+    """Extract features from a concatenated dataset of windows.
+
+    This function applies a feature extractor to each `WindowsDataset` within a
+    `BaseConcatDataset` in parallel and returns a `FeaturesConcatDataset`
+    with the results.
+
+    Parameters
+    ----------
+    concat_dataset : BaseConcatDataset
+        A concatenated dataset of `WindowsDataset` or `EEGWindowsDataset`
+        instances.
+    features : FeatureExtractor or dict or list
+        The feature extractor(s) to apply. Can be a `FeatureExtractor`
+        instance, a dictionary of named feature functions, or a list of
+        feature functions.
+    batch_size : int, default 512
+        The size of batches to use for feature extraction.
+    n_jobs : int, default 1
+        The number of parallel jobs to use for extracting features from the
+        datasets.
+
+    Returns
+    -------
+    FeaturesConcatDataset
+        A new concatenated dataset containing the extracted features.
+
+    """
     if isinstance(features, list):
         features = dict(enumerate(features))
     if not isinstance(features, FeatureExtractor):
@@ -97,7 +143,28 @@ def fit_feature_extractors(
     concat_dataset: BaseConcatDataset,
     features: FeatureExtractor | Dict[str, Callable] | List[Callable],
     batch_size: int = 8192,
-):
+) -> FeatureExtractor:
+    """Fit trainable feature extractors on a dataset.
+
+    If the provided feature extractor (or any of its sub-extractors) is
+    trainable (i.e., subclasses `TrainableFeature`), this function iterates
+    through the dataset to fit it.
+
+    Parameters
+    ----------
+    concat_dataset : BaseConcatDataset
+        The dataset to use for fitting the feature extractors.
+    features : FeatureExtractor or dict or list
+        The feature extractor(s) to fit.
+    batch_size : int, default 8192
+        The batch size to use when iterating through the dataset for fitting.
+
+    Returns
+    -------
+    FeatureExtractor
+        The fitted feature extractor.
+
+    """
     if isinstance(features, list):
         features = dict(enumerate(features))
     if not isinstance(features, FeatureExtractor):
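
For reference, a sketch of how the documented extract_features and fit_feature_extractors API might be combined, assuming both are importable from eegdash.features, that windows_concat_ds is an existing braindecode BaseConcatDataset of windows, and that plain per-window callables are accepted as feature functions (their exact call signature is not shown in this diff):

    import numpy as np

    from eegdash.features import extract_features, fit_feature_extractors  # paths assumed

    # Illustrative feature functions mapping a window array to per-channel values.
    features = {
        "variance": lambda x: np.var(x, axis=-1),
        "ptp": lambda x: np.ptp(x, axis=-1),
    }

    # Fit any trainable extractors first (a no-op for plain callables),
    # then extract features from each recording in parallel.
    extractor = fit_feature_extractors(windows_concat_ds, features)
    feat_ds = extract_features(windows_concat_ds, extractor, batch_size=512, n_jobs=4)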
eegdash/hbn/__init__.py CHANGED
@@ -1,3 +1,14 @@
+# Authors: The EEGDash contributors.
+# License: GNU General Public License
+# Copyright the EEGDash contributors.
+
+"""Healthy Brain Network (HBN) specific utilities and preprocessing.
+
+This module provides specialized functions for working with the Healthy Brain Network
+dataset, including preprocessing pipelines, annotation handling, and windowing utilities
+tailored for HBN EEG data analysis.
+"""
+
 from .preprocessing import hbn_ec_ec_reannotation
 from .windows import (
     add_aux_anchors,
eegdash/hbn/preprocessing.py CHANGED
@@ -1,35 +1,64 @@
-import logging
+# Authors: The EEGDash contributors.
+# License: GNU General Public License
+# Copyright the EEGDash contributors.
+
+"""Preprocessing utilities specific to the Healthy Brain Network dataset.
+
+This module contains preprocessing classes and functions designed specifically for
+HBN EEG data, including specialized annotation handling for eyes-open/eyes-closed
+paradigms and other HBN-specific preprocessing steps.
+"""
 
 import mne
 import numpy as np
 
 from braindecode.preprocessing import Preprocessor
 
-logger = logging.getLogger("eegdash")
+from ..logging import logger
 
 
 class hbn_ec_ec_reannotation(Preprocessor):
-    """Preprocessor to reannotate the raw data for eyes open and eyes closed events.
+    """Preprocessor to reannotate HBN data for eyes-open/eyes-closed events.
+
+    This preprocessor is specifically designed for Healthy Brain Network (HBN)
+    datasets. It identifies existing annotations for "instructed_toCloseEyes"
+    and "instructed_toOpenEyes" and creates new, regularly spaced annotations
+    for "eyes_closed" and "eyes_open" segments, respectively.
 
-    This processor is designed for HBN datasets.
+    This is useful for creating windowed datasets based on these new, more
+    precise event markers.
+
+    Notes
+    -----
+    This class inherits from :class:`braindecode.preprocessing.Preprocessor`
+    and is intended to be used within a braindecode preprocessing pipeline.
 
     """
 
     def __init__(self):
         super().__init__(fn=self.transform, apply_on_array=False)
 
-    def transform(self, raw):
-        """Reannotate the raw data to create new events for eyes open and eyes closed
+    def transform(self, raw: mne.io.Raw) -> mne.io.Raw:
+        """Create new annotations for eyes-open and eyes-closed periods.
+
+        This function finds the original "instructed_to..." annotations and
+        generates new annotations every 2 seconds within specific time ranges
+        relative to the original markers:
+        - "eyes_closed": 15s to 29s after "instructed_toCloseEyes"
+        - "eyes_open": 5s to 19s after "instructed_toOpenEyes"
 
-        This function modifies the raw MNE object by creating new events based on
-        the existing annotations for "instructed_toCloseEyes" and "instructed_toOpenEyes".
-        It generates new events every 2 seconds within specified time ranges after
-        the original events, and replaces the existing annotations with these new events.
+        The original annotations in the `mne.io.Raw` object are replaced by
+        this new set of annotations.
 
         Parameters
        ----------
         raw : mne.io.Raw
-            The raw MNE object containing EEG data and annotations.
+            The raw MNE object containing the HBN data and original annotations.
+
+        Returns
+        -------
+        mne.io.Raw
+            The raw MNE object with the modified annotations.
 
         """
         events, event_id = mne.events_from_annotations(raw)
@@ -39,15 +68,27 @@ class hbn_ec_ec_reannotation(Preprocessor):
         # Create new events array for 2-second segments
         new_events = []
         sfreq = raw.info["sfreq"]
-        for event in events[events[:, 2] == event_id["instructed_toCloseEyes"]]:
-            # For each original event, create events every 2 seconds from 15s to 29s after
-            start_times = event[0] + np.arange(15, 29, 2) * sfreq
-            new_events.extend([[int(t), 0, 1] for t in start_times])
 
-        for event in events[events[:, 2] == event_id["instructed_toOpenEyes"]]:
-            # For each original event, create events every 2 seconds from 5s to 19s after
-            start_times = event[0] + np.arange(5, 19, 2) * sfreq
-            new_events.extend([[int(t), 0, 2] for t in start_times])
+        close_event_id = event_id.get("instructed_toCloseEyes")
+        if close_event_id:
+            for event in events[events[:, 2] == close_event_id]:
+                # For each original event, create events every 2s from 15s to 29s after
+                start_times = event[0] + np.arange(15, 29, 2) * sfreq
+                new_events.extend([[int(t), 0, 1] for t in start_times])
+
+        open_event_id = event_id.get("instructed_toOpenEyes")
+        if open_event_id:
+            for event in events[events[:, 2] == open_event_id]:
+                # For each original event, create events every 2s from 5s to 19s after
+                start_times = event[0] + np.arange(5, 19, 2) * sfreq
+                new_events.extend([[int(t), 0, 2] for t in start_times])
+
+        if not new_events:
+            logger.warning(
+                "Could not find 'instructed_toCloseEyes' or 'instructed_toOpenEyes' "
+                "annotations. No new events created."
+            )
+            return raw
 
         # replace events in raw
         new_events = np.array(new_events)
@@ -56,6 +97,7 @@ class hbn_ec_ec_reannotation(Preprocessor):
             events=new_events,
             event_desc={1: "eyes_closed", 2: "eyes_open"},
             sfreq=raw.info["sfreq"],
+            orig_time=raw.info.get("meas_date"),
         )
 
         raw.set_annotations(annot_from_events)
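
Because hbn_ec_ec_reannotation subclasses braindecode's Preprocessor, the new guard clause does not change how it is invoked. A sketch of its intended use in a preprocessing pipeline, assuming dataset is an already-loaded HBN BaseConcatDataset and the filter settings are illustrative:

    from braindecode.preprocessing import Preprocessor, preprocess

    from eegdash.hbn import hbn_ec_ec_reannotation

    # Rewrite the resting-state markers into eyes_closed/eyes_open events,
    # then apply an ordinary band-pass filter.
    preprocess(
        dataset,
        [
            hbn_ec_ec_reannotation(),
            Preprocessor("filter", l_freq=1.0, h_freq=40.0),
        ],
    )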
eegdash/hbn/windows.py CHANGED
@@ -1,3 +1,15 @@
+# Authors: The EEGDash contributors.
+# License: GNU General Public License
+# Copyright the EEGDash contributors.
+
+"""Windowing and trial processing utilities for HBN datasets.
+
+This module provides functions for building trial tables, adding auxiliary anchors,
+annotating trials with targets, and filtering recordings based on various criteria.
+These utilities are specifically designed for working with HBN EEG data structures
+and experimental paradigms.
+"""
+
 import logging
 
 import mne
@@ -7,11 +19,27 @@ from mne_bids import get_bids_path_from_fname
 
 from braindecode.datasets.base import BaseConcatDataset
 
-logger = logging.getLogger("eegdash")
-
 
 def build_trial_table(events_df: pd.DataFrame) -> pd.DataFrame:
-    """One row per contrast trial with stimulus/response metrics."""
+    """Build a table of contrast trials from an events DataFrame.
+
+    This function processes a DataFrame of events (typically from a BIDS
+    `events.tsv` file) to identify contrast trials and extract relevant
+    metrics like stimulus onset, response onset, and reaction times.
+
+    Parameters
+    ----------
+    events_df : pandas.DataFrame
+        A DataFrame containing event information, with at least "onset" and
+        "value" columns.
+
+    Returns
+    -------
+    pandas.DataFrame
+        A DataFrame where each row represents a single contrast trial, with
+        columns for onsets, reaction times, and response correctness.
+
+    """
     events_df = events_df.copy()
     events_df["onset"] = pd.to_numeric(events_df["onset"], errors="raise")
     events_df = events_df.sort_values("onset", kind="mergesort").reset_index(drop=True)
@@ -82,12 +110,13 @@
     return pd.DataFrame(rows)
 
 
-# Aux functions to inject the annot
 def _to_float_or_none(x):
+    """Safely convert a value to float or None."""
     return None if pd.isna(x) else float(x)
 
 
 def _to_int_or_none(x):
+    """Safely convert a value to int or None."""
     if pd.isna(x):
         return None
     if isinstance(x, (bool, np.bool_)):
@@ -96,22 +125,55 @@
         return int(x)
     try:
         return int(x)
-    except Exception:
+    except (ValueError, TypeError):
         return None
 
 
 def _to_str_or_none(x):
+    """Safely convert a value to string or None."""
     return None if (x is None or (isinstance(x, float) and np.isnan(x))) else str(x)
 
 
 def annotate_trials_with_target(
-    raw,
-    target_field="rt_from_stimulus",
-    epoch_length=2.0,
-    require_stimulus=True,
-    require_response=True,
-):
-    """Create 'contrast_trial_start' annotations with float target in extras."""
+    raw: mne.io.Raw,
+    target_field: str = "rt_from_stimulus",
+    epoch_length: float = 2.0,
+    require_stimulus: bool = True,
+    require_response: bool = True,
+) -> mne.io.Raw:
+    """Create trial annotations with a specified target value.
+
+    This function reads the BIDS events file associated with the `raw` object,
+    builds a trial table, and creates new MNE annotations for each trial.
+    The annotations are labeled "contrast_trial_start" and their `extras`
+    dictionary is populated with trial metrics, including a "target" key.
+
+    Parameters
+    ----------
+    raw : mne.io.Raw
+        The raw data object. Must have a single associated file name from
+        which the BIDS path can be derived.
+    target_field : str, default "rt_from_stimulus"
+        The column from the trial table to use as the "target" value in the
+        annotation extras.
+    epoch_length : float, default 2.0
+        The duration to set for each new annotation.
+    require_stimulus : bool, default True
+        If True, only include trials that have a recorded stimulus event.
+    require_response : bool, default True
+        If True, only include trials that have a recorded response event.
+
+    Returns
+    -------
+    mne.io.Raw
+        The `raw` object with the new annotations set.
+
+    Raises
+    ------
+    KeyError
+        If `target_field` is not a valid column in the built trial table.
+
+    """
     fnames = raw.filenames
     assert len(fnames) == 1, "Expected a single filename"
     bids_path = get_bids_path_from_fname(fnames[0])
@@ -142,7 +204,6 @@
     extras = []
     for i, v in enumerate(targets):
         row = trials.iloc[i]
-
         extras.append(
             {
                 "target": _to_float_or_none(v),
@@ -159,14 +220,39 @@
         onset=onsets,
         duration=durations,
         description=descs,
-        orig_time=raw.info["meas_date"],
+        orig_time=raw.info.get("meas_date"),
         extras=extras,
     )
     raw.set_annotations(new_ann, verbose=False)
     return raw
 
 
-def add_aux_anchors(raw, stim_desc="stimulus_anchor", resp_desc="response_anchor"):
+def add_aux_anchors(
+    raw: mne.io.Raw,
+    stim_desc: str = "stimulus_anchor",
+    resp_desc: str = "response_anchor",
+) -> mne.io.Raw:
+    """Add auxiliary annotations for stimulus and response onsets.
+
+    This function inspects existing "contrast_trial_start" annotations and
+    adds new, zero-duration "anchor" annotations at the precise onsets of
+    stimuli and responses for each trial.
+
+    Parameters
+    ----------
+    raw : mne.io.Raw
+        The raw data object with "contrast_trial_start" annotations.
+    stim_desc : str, default "stimulus_anchor"
+        The description for the new stimulus annotations.
+    resp_desc : str, default "response_anchor"
+        The description for the new response annotations.
+
+    Returns
+    -------
+    mne.io.Raw
+        The `raw` object with the auxiliary annotations added.
+
+    """
     ann = raw.annotations
     mask = ann.description == "contrast_trial_start"
     if not np.any(mask):
@@ -179,28 +265,24 @@ def add_aux_anchors(raw, stim_desc="stimulus_anchor", resp_desc="response_anchor
         ex = ann.extras[idx] if ann.extras is not None else {}
         t0 = float(ann.onset[idx])
 
-        stim_t = ex["stimulus_onset"]
-        resp_t = ex["response_onset"]
+        stim_t = ex.get("stimulus_onset")
+        resp_t = ex.get("response_onset")
 
         if stim_t is None or (isinstance(stim_t, float) and np.isnan(stim_t)):
-            rtt = ex["rt_from_trialstart"]
-            rts = ex["rt_from_stimulus"]
+            rtt = ex.get("rt_from_trialstart")
+            rts = ex.get("rt_from_stimulus")
             if rtt is not None and rts is not None:
                 stim_t = t0 + float(rtt) - float(rts)
 
         if resp_t is None or (isinstance(resp_t, float) and np.isnan(resp_t)):
-            rtt = ex["rt_from_trialstart"]
+            rtt = ex.get("rt_from_trialstart")
             if rtt is not None:
                 resp_t = t0 + float(rtt)
 
-        if (stim_t is not None) and not (
-            isinstance(stim_t, float) and np.isnan(stim_t)
-        ):
+        if stim_t is not None and not (isinstance(stim_t, float) and np.isnan(stim_t)):
             stim_onsets.append(float(stim_t))
             stim_extras.append(dict(ex, anchor="stimulus"))
-        if (resp_t is not None) and not (
-            isinstance(resp_t, float) and np.isnan(resp_t)
-        ):
+        if resp_t is not None and not (isinstance(resp_t, float) and np.isnan(resp_t)):
             resp_onsets.append(float(resp_t))
             resp_extras.append(dict(ex, anchor="response"))
 
@@ -210,7 +292,7 @@ def add_aux_anchors(raw, stim_desc="stimulus_anchor", resp_desc="response_anchor
         onset=new_onsets,
         duration=np.zeros_like(new_onsets, dtype=float),
         description=[stim_desc] * len(stim_onsets) + [resp_desc] * len(resp_onsets),
-        orig_time=raw.info["meas_date"],
+        orig_time=raw.info.get("meas_date"),
         extras=stim_extras + resp_extras,
     )
     raw.set_annotations(ann + aux, verbose=False)
@@ -218,10 +300,10 @@
 
 
 def add_extras_columns(
-    windows_concat_ds,
-    original_concat_ds,
-    desc="contrast_trial_start",
-    keys=(
+    windows_concat_ds: BaseConcatDataset,
+    original_concat_ds: BaseConcatDataset,
+    desc: str = "contrast_trial_start",
+    keys: tuple = (
         "target",
         "rt_from_stimulus",
         "rt_from_trialstart",
@@ -230,7 +312,31 @@
         "correct",
         "response_type",
     ),
-):
+) -> BaseConcatDataset:
+    """Add columns from annotation extras to a windowed dataset's metadata.
+
+    This function propagates trial-level information stored in the `extras`
+    of annotations to the `metadata` DataFrame of a `WindowsDataset`.
+
+    Parameters
+    ----------
+    windows_concat_ds : BaseConcatDataset
+        The windowed dataset whose metadata will be updated.
+    original_concat_ds : BaseConcatDataset
+        The original (non-windowed) dataset containing the raw data and
+        annotations with the `extras` to be added.
+    desc : str, default "contrast_trial_start"
+        The description of the annotations to source the extras from.
+    keys : tuple, default (...)
+        The keys to extract from each annotation's `extras` dictionary and
+        add as columns to the metadata.
+
+    Returns
+    -------
+    BaseConcatDataset
+        The `windows_concat_ds` with updated metadata.
+
+    """
     float_cols = {
         "target",
         "rt_from_stimulus",
@@ -282,7 +388,6 @@
             else: # response_type
                 ser = pd.Series(vals, index=md.index, dtype="string")
 
-            # Replace the whole column to avoid dtype conflicts
             md[k] = ser
 
         win_ds.metadata = md.reset_index(drop=True)
@@ -293,7 +398,25 @@
     return windows_concat_ds
 
 
-def keep_only_recordings_with(desc, concat_ds):
+def keep_only_recordings_with(
+    desc: str, concat_ds: BaseConcatDataset
+) -> BaseConcatDataset:
+    """Filter a concatenated dataset to keep only recordings with a specific annotation.
+
+    Parameters
+    ----------
+    desc : str
+        The description of the annotation that must be present in a recording
+        for it to be kept.
+    concat_ds : BaseConcatDataset
+        The concatenated dataset to filter.
+
+    Returns
+    -------
+    BaseConcatDataset
+        A new concatenated dataset containing only the filtered recordings.
+
+    """
     kept = []
     for ds in concat_ds.datasets:
         if np.any(ds.raw.annotations.description == desc):
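
The windowing helpers above are meant to be chained. A sketch of that flow, assuming dataset is an HBN BaseConcatDataset, that all four helpers are exported from eegdash.hbn, and with illustrative windowing parameters:

    from braindecode.preprocessing import create_windows_from_events

    from eegdash.hbn import (
        add_aux_anchors,
        add_extras_columns,
        annotate_trials_with_target,
        keep_only_recordings_with,
    )

    # Turn BIDS events into contrast-trial annotations plus anchor markers.
    for ds in dataset.datasets:
        annotate_trials_with_target(ds.raw, target_field="rt_from_stimulus")
        add_aux_anchors(ds.raw)

    # Keep only recordings that actually contain contrast trials.
    dataset = keep_only_recordings_with("contrast_trial_start", dataset)

    # Cut windows at each trial start and copy the trial-level extras
    # (reaction times, correctness, ...) into the window metadata.
    windows_ds = create_windows_from_events(dataset, mapping={"contrast_trial_start": 0})
    windows_ds = add_extras_columns(windows_ds, dataset)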