eegdash 0.4.0.dev153__py3-none-any.whl → 0.4.0.dev162__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of eegdash might be problematic. Click here for more details.

@@ -1,7 +1,8 @@
1
1
  """Convenience functions for storing and loading features datasets.
2
2
 
3
- See Also:
4
- https://github.com/braindecode/braindecode//blob/master/braindecode/datautil/serialization.py#L165-L229
3
+ See Also
4
+ --------
5
+ https://github.com/braindecode/braindecode//blob/master/braindecode/datautil/serialization.py#L165-L229
5
6
 
6
7
  """
7
8
 
@@ -16,34 +17,40 @@ from braindecode.datautil.serialization import _load_kwargs_json
16
17
  from .datasets import FeaturesConcatDataset, FeaturesDataset
17
18
 
18
19
 
19
- def load_features_concat_dataset(path, ids_to_load=None, n_jobs=1):
20
- """Load a stored features dataset from files.
20
+ def load_features_concat_dataset(
21
+ path: str | Path, ids_to_load: list[int] | None = None, n_jobs: int = 1
22
+ ) -> FeaturesConcatDataset:
23
+ """Load a stored `FeaturesConcatDataset` from a directory.
24
+
25
+ This function reconstructs a :class:`FeaturesConcatDataset` by loading
26
+ individual :class:`FeaturesDataset` instances from subdirectories within
27
+ the given path. It uses joblib for parallel loading.
21
28
 
22
29
  Parameters
23
30
  ----------
24
- path: str | pathlib.Path
25
- Path to the directory of the .fif / -epo.fif and .json files.
26
- ids_to_load: list of int | None
27
- Ids of specific files to load.
28
- n_jobs: int
29
- Number of jobs to be used to read files in parallel.
31
+ path : str or pathlib.Path
32
+ The path to the directory where the dataset was saved. This directory
33
+ should contain subdirectories (e.g., "0", "1", "2", ...) for each
34
+ individual dataset.
35
+ ids_to_load : list of int, optional
36
+ A list of specific dataset IDs (subdirectory names) to load. If None,
37
+ all subdirectories in the path will be loaded.
38
+ n_jobs : int, default 1
39
+ The number of jobs to use for parallel loading. -1 means using all
40
+ processors.
30
41
 
31
42
  Returns
32
43
  -------
33
- concat_dataset: eegdash.features.datasets.FeaturesConcatDataset
34
- A concatenation of multiple eegdash.features.datasets.FeaturesDataset
35
- instances loaded from the given directory.
44
+ eegdash.features.datasets.FeaturesConcatDataset
45
+ A concatenated dataset containing the loaded `FeaturesDataset` instances.
36
46
 
37
47
  """
38
48
  # Make sure we always work with a pathlib.Path
39
49
  path = Path(path)
40
50
 
41
- # else we have a dataset saved in the new way with subdirectories in path
42
- # for every dataset with description.json and -feat.parquet,
43
- # target_name.json, raw_preproc_kwargs.json, window_kwargs.json,
44
- # window_preproc_kwargs.json, features_kwargs.json
45
51
  if ids_to_load is None:
46
- ids_to_load = [p.name for p in path.iterdir()]
52
+ # Get all subdirectories and sort them numerically
53
+ ids_to_load = [p.name for p in path.iterdir() if p.is_dir()]
47
54
  ids_to_load = sorted(ids_to_load, key=lambda i: int(i))
48
55
  ids_to_load = [str(i) for i in ids_to_load]
49
56
 
@@ -51,7 +58,26 @@ def load_features_concat_dataset(path, ids_to_load=None, n_jobs=1):
51
58
  return FeaturesConcatDataset(datasets)
52
59
 
53
60
 
54
- def _load_parallel(path, i):
61
+ def _load_parallel(path: Path, i: str) -> FeaturesDataset:
62
+ """Load a single `FeaturesDataset` from its subdirectory.
63
+
64
+ This is a helper function for `load_features_concat_dataset` that handles
65
+ the loading of one dataset's files (features, metadata, descriptions, etc.).
66
+
67
+ Parameters
68
+ ----------
69
+ path : pathlib.Path
70
+ The root directory of the saved `FeaturesConcatDataset`.
71
+ i : str
72
+ The identifier of the dataset to load, corresponding to its
73
+ subdirectory name.
74
+
75
+ Returns
76
+ -------
77
+ eegdash.features.datasets.FeaturesDataset
78
+ The loaded dataset instance.
79
+
80
+ """
55
81
  sub_dir = path / i
56
82
 
57
83
  parquet_name_pattern = "{}-feat.parquet"
eegdash/features/utils.py CHANGED
@@ -22,7 +22,28 @@ def _extract_features_from_windowsdataset(
22
22
  win_ds: EEGWindowsDataset | WindowsDataset,
23
23
  feature_extractor: FeatureExtractor,
24
24
  batch_size: int = 512,
25
- ):
25
+ ) -> FeaturesDataset:
26
+ """Extract features from a single `WindowsDataset`.
27
+
28
+ This is a helper function that iterates through a `WindowsDataset` in
29
+ batches, applies a `FeatureExtractor`, and returns the results as a
30
+ `FeaturesDataset`.
31
+
32
+ Parameters
33
+ ----------
34
+ win_ds : EEGWindowsDataset or WindowsDataset
35
+ The windowed dataset to extract features from.
36
+ feature_extractor : FeatureExtractor
37
+ The feature extractor instance to apply.
38
+ batch_size : int, default 512
39
+ The number of windows to process in each batch.
40
+
41
+ Returns
42
+ -------
43
+ FeaturesDataset
44
+ A new dataset containing the extracted features and associated metadata.
45
+
46
+ """
26
47
  metadata = win_ds.metadata
27
48
  if not win_ds.targets_from == "metadata":
28
49
  metadata = copy.deepcopy(metadata)
@@ -51,18 +72,16 @@ def _extract_features_from_windowsdataset(
51
72
  features_dict[k].extend(v)
52
73
  features_df = pd.DataFrame(features_dict)
53
74
  if not win_ds.targets_from == "metadata":
54
- metadata.set_index("orig_index", drop=False, inplace=True)
55
75
  metadata.reset_index(drop=True, inplace=True)
56
- metadata.drop("orig_index", axis=1, inplace=True)
76
+ metadata.drop("orig_index", axis=1, inplace=True, errors="ignore")
57
77
 
58
- # FUTURE: truly support WindowsDataset objects
59
78
  return FeaturesDataset(
60
79
  features_df,
61
80
  metadata=metadata,
62
81
  description=win_ds.description,
63
82
  raw_info=win_ds.raw.info,
64
- raw_preproc_kwargs=win_ds.raw_preproc_kwargs,
65
- window_kwargs=win_ds.window_kwargs,
83
+ raw_preproc_kwargs=getattr(win_ds, "raw_preproc_kwargs", None),
84
+ window_kwargs=getattr(win_ds, "window_kwargs", None),
66
85
  features_kwargs=feature_extractor.features_kwargs,
67
86
  )
68
87
 
@@ -73,7 +92,34 @@ def extract_features(
73
92
  *,
74
93
  batch_size: int = 512,
75
94
  n_jobs: int = 1,
76
- ):
95
+ ) -> FeaturesConcatDataset:
96
+ """Extract features from a concatenated dataset of windows.
97
+
98
+ This function applies a feature extractor to each `WindowsDataset` within a
99
+ `BaseConcatDataset` in parallel and returns a `FeaturesConcatDataset`
100
+ with the results.
101
+
102
+ Parameters
103
+ ----------
104
+ concat_dataset : BaseConcatDataset
105
+ A concatenated dataset of `WindowsDataset` or `EEGWindowsDataset`
106
+ instances.
107
+ features : FeatureExtractor or dict or list
108
+ The feature extractor(s) to apply. Can be a `FeatureExtractor`
109
+ instance, a dictionary of named feature functions, or a list of
110
+ feature functions.
111
+ batch_size : int, default 512
112
+ The size of batches to use for feature extraction.
113
+ n_jobs : int, default 1
114
+ The number of parallel jobs to use for extracting features from the
115
+ datasets.
116
+
117
+ Returns
118
+ -------
119
+ FeaturesConcatDataset
120
+ A new concatenated dataset containing the extracted features.
121
+
122
+ """
77
123
  if isinstance(features, list):
78
124
  features = dict(enumerate(features))
79
125
  if not isinstance(features, FeatureExtractor):
@@ -97,7 +143,28 @@ def fit_feature_extractors(
97
143
  concat_dataset: BaseConcatDataset,
98
144
  features: FeatureExtractor | Dict[str, Callable] | List[Callable],
99
145
  batch_size: int = 8192,
100
- ):
146
+ ) -> FeatureExtractor:
147
+ """Fit trainable feature extractors on a dataset.
148
+
149
+ If the provided feature extractor (or any of its sub-extractors) is
150
+ trainable (i.e., subclasses `TrainableFeature`), this function iterates
151
+ through the dataset to fit it.
152
+
153
+ Parameters
154
+ ----------
155
+ concat_dataset : BaseConcatDataset
156
+ The dataset to use for fitting the feature extractors.
157
+ features : FeatureExtractor or dict or list
158
+ The feature extractor(s) to fit.
159
+ batch_size : int, default 8192
160
+ The batch size to use when iterating through the dataset for fitting.
161
+
162
+ Returns
163
+ -------
164
+ FeatureExtractor
165
+ The fitted feature extractor.
166
+
167
+ """
101
168
  if isinstance(features, list):
102
169
  features = dict(enumerate(features))
103
170
  if not isinstance(features, FeatureExtractor):
@@ -18,27 +18,47 @@ from ..logging import logger
18
18
 
19
19
 
20
20
  class hbn_ec_ec_reannotation(Preprocessor):
21
- """Preprocessor to reannotate the raw data for eyes open and eyes closed events.
21
+ """Preprocessor to reannotate HBN data for eyes-open/eyes-closed events.
22
22
 
23
- This processor is designed for HBN datasets.
23
+ This preprocessor is specifically designed for Healthy Brain Network (HBN)
24
+ datasets. It identifies existing annotations for "instructed_toCloseEyes"
25
+ and "instructed_toOpenEyes" and creates new, regularly spaced annotations
26
+ for "eyes_closed" and "eyes_open" segments, respectively.
27
+
28
+ This is useful for creating windowed datasets based on these new, more
29
+ precise event markers.
30
+
31
+ Notes
32
+ -----
33
+ This class inherits from :class:`braindecode.preprocessing.Preprocessor`
34
+ and is intended to be used within a braindecode preprocessing pipeline.
24
35
 
25
36
  """
26
37
 
27
38
  def __init__(self):
28
39
  super().__init__(fn=self.transform, apply_on_array=False)
29
40
 
30
- def transform(self, raw):
31
- """Reannotate the raw data to create new events for eyes open and eyes closed
41
+ def transform(self, raw: mne.io.Raw) -> mne.io.Raw:
42
+ """Create new annotations for eyes-open and eyes-closed periods.
32
43
 
33
- This function modifies the raw MNE object by creating new events based on
34
- the existing annotations for "instructed_toCloseEyes" and "instructed_toOpenEyes".
35
- It generates new events every 2 seconds within specified time ranges after
36
- the original events, and replaces the existing annotations with these new events.
44
+ This function finds the original "instructed_to..." annotations and
45
+ generates new annotations every 2 seconds within specific time ranges
46
+ relative to the original markers:
47
+ - "eyes_closed": 15s to 29s after "instructed_toCloseEyes"
48
+ - "eyes_open": 5s to 19s after "instructed_toOpenEyes"
49
+
50
+ The original annotations in the `mne.io.Raw` object are replaced by
51
+ this new set of annotations.
37
52
 
38
53
  Parameters
39
54
  ----------
40
55
  raw : mne.io.Raw
41
- The raw MNE object containing EEG data and annotations.
56
+ The raw MNE object containing the HBN data and original annotations.
57
+
58
+ Returns
59
+ -------
60
+ mne.io.Raw
61
+ The raw MNE object with the modified annotations.
42
62
 
43
63
  """
44
64
  events, event_id = mne.events_from_annotations(raw)
@@ -48,15 +68,27 @@ class hbn_ec_ec_reannotation(Preprocessor):
48
68
  # Create new events array for 2-second segments
49
69
  new_events = []
50
70
  sfreq = raw.info["sfreq"]
51
- for event in events[events[:, 2] == event_id["instructed_toCloseEyes"]]:
52
- # For each original event, create events every 2 seconds from 15s to 29s after
53
- start_times = event[0] + np.arange(15, 29, 2) * sfreq
54
- new_events.extend([[int(t), 0, 1] for t in start_times])
55
71
 
56
- for event in events[events[:, 2] == event_id["instructed_toOpenEyes"]]:
57
- # For each original event, create events every 2 seconds from 5s to 19s after
58
- start_times = event[0] + np.arange(5, 19, 2) * sfreq
59
- new_events.extend([[int(t), 0, 2] for t in start_times])
72
+ close_event_id = event_id.get("instructed_toCloseEyes")
73
+ if close_event_id:
74
+ for event in events[events[:, 2] == close_event_id]:
75
+ # For each original event, create events every 2s from 15s to 29s after
76
+ start_times = event[0] + np.arange(15, 29, 2) * sfreq
77
+ new_events.extend([[int(t), 0, 1] for t in start_times])
78
+
79
+ open_event_id = event_id.get("instructed_toOpenEyes")
80
+ if open_event_id:
81
+ for event in events[events[:, 2] == open_event_id]:
82
+ # For each original event, create events every 2s from 5s to 19s after
83
+ start_times = event[0] + np.arange(5, 19, 2) * sfreq
84
+ new_events.extend([[int(t), 0, 2] for t in start_times])
85
+
86
+ if not new_events:
87
+ logger.warning(
88
+ "Could not find 'instructed_toCloseEyes' or 'instructed_toOpenEyes' "
89
+ "annotations. No new events created."
90
+ )
91
+ return raw
60
92
 
61
93
  # replace events in raw
62
94
  new_events = np.array(new_events)
@@ -65,6 +97,7 @@ class hbn_ec_ec_reannotation(Preprocessor):
65
97
  events=new_events,
66
98
  event_desc={1: "eyes_closed", 2: "eyes_open"},
67
99
  sfreq=raw.info["sfreq"],
100
+ orig_time=raw.info.get("meas_date"),
68
101
  )
69
102
 
70
103
  raw.set_annotations(annot_from_events)
eegdash/hbn/windows.py CHANGED
@@ -21,7 +21,25 @@ from braindecode.datasets.base import BaseConcatDataset
21
21
 
22
22
 
23
23
  def build_trial_table(events_df: pd.DataFrame) -> pd.DataFrame:
24
- """One row per contrast trial with stimulus/response metrics."""
24
+ """Build a table of contrast trials from an events DataFrame.
25
+
26
+ This function processes a DataFrame of events (typically from a BIDS
27
+ `events.tsv` file) to identify contrast trials and extract relevant
28
+ metrics like stimulus onset, response onset, and reaction times.
29
+
30
+ Parameters
31
+ ----------
32
+ events_df : pandas.DataFrame
33
+ A DataFrame containing event information, with at least "onset" and
34
+ "value" columns.
35
+
36
+ Returns
37
+ -------
38
+ pandas.DataFrame
39
+ A DataFrame where each row represents a single contrast trial, with
40
+ columns for onsets, reaction times, and response correctness.
41
+
42
+ """
25
43
  events_df = events_df.copy()
26
44
  events_df["onset"] = pd.to_numeric(events_df["onset"], errors="raise")
27
45
  events_df = events_df.sort_values("onset", kind="mergesort").reset_index(drop=True)
@@ -92,12 +110,13 @@ def build_trial_table(events_df: pd.DataFrame) -> pd.DataFrame:
92
110
  return pd.DataFrame(rows)
93
111
 
94
112
 
95
- # Aux functions to inject the annot
96
113
  def _to_float_or_none(x):
114
+ """Safely convert a value to float or None."""
97
115
  return None if pd.isna(x) else float(x)
98
116
 
99
117
 
100
118
  def _to_int_or_none(x):
119
+ """Safely convert a value to int or None."""
101
120
  if pd.isna(x):
102
121
  return None
103
122
  if isinstance(x, (bool, np.bool_)):
@@ -106,22 +125,55 @@ def _to_int_or_none(x):
106
125
  return int(x)
107
126
  try:
108
127
  return int(x)
109
- except Exception:
128
+ except (ValueError, TypeError):
110
129
  return None
111
130
 
112
131
 
113
132
  def _to_str_or_none(x):
133
+ """Safely convert a value to string or None."""
114
134
  return None if (x is None or (isinstance(x, float) and np.isnan(x))) else str(x)
115
135
 
116
136
 
117
137
  def annotate_trials_with_target(
118
- raw,
119
- target_field="rt_from_stimulus",
120
- epoch_length=2.0,
121
- require_stimulus=True,
122
- require_response=True,
123
- ):
124
- """Create 'contrast_trial_start' annotations with float target in extras."""
138
+ raw: mne.io.Raw,
139
+ target_field: str = "rt_from_stimulus",
140
+ epoch_length: float = 2.0,
141
+ require_stimulus: bool = True,
142
+ require_response: bool = True,
143
+ ) -> mne.io.Raw:
144
+ """Create trial annotations with a specified target value.
145
+
146
+ This function reads the BIDS events file associated with the `raw` object,
147
+ builds a trial table, and creates new MNE annotations for each trial.
148
+ The annotations are labeled "contrast_trial_start" and their `extras`
149
+ dictionary is populated with trial metrics, including a "target" key.
150
+
151
+ Parameters
152
+ ----------
153
+ raw : mne.io.Raw
154
+ The raw data object. Must have a single associated file name from
155
+ which the BIDS path can be derived.
156
+ target_field : str, default "rt_from_stimulus"
157
+ The column from the trial table to use as the "target" value in the
158
+ annotation extras.
159
+ epoch_length : float, default 2.0
160
+ The duration to set for each new annotation.
161
+ require_stimulus : bool, default True
162
+ If True, only include trials that have a recorded stimulus event.
163
+ require_response : bool, default True
164
+ If True, only include trials that have a recorded response event.
165
+
166
+ Returns
167
+ -------
168
+ mne.io.Raw
169
+ The `raw` object with the new annotations set.
170
+
171
+ Raises
172
+ ------
173
+ KeyError
174
+ If `target_field` is not a valid column in the built trial table.
175
+
176
+ """
125
177
  fnames = raw.filenames
126
178
  assert len(fnames) == 1, "Expected a single filename"
127
179
  bids_path = get_bids_path_from_fname(fnames[0])
@@ -152,7 +204,6 @@ def annotate_trials_with_target(
152
204
  extras = []
153
205
  for i, v in enumerate(targets):
154
206
  row = trials.iloc[i]
155
-
156
207
  extras.append(
157
208
  {
158
209
  "target": _to_float_or_none(v),
@@ -169,14 +220,39 @@ def annotate_trials_with_target(
169
220
  onset=onsets,
170
221
  duration=durations,
171
222
  description=descs,
172
- orig_time=raw.info["meas_date"],
223
+ orig_time=raw.info.get("meas_date"),
173
224
  extras=extras,
174
225
  )
175
226
  raw.set_annotations(new_ann, verbose=False)
176
227
  return raw
177
228
 
178
229
 
179
- def add_aux_anchors(raw, stim_desc="stimulus_anchor", resp_desc="response_anchor"):
230
+ def add_aux_anchors(
231
+ raw: mne.io.Raw,
232
+ stim_desc: str = "stimulus_anchor",
233
+ resp_desc: str = "response_anchor",
234
+ ) -> mne.io.Raw:
235
+ """Add auxiliary annotations for stimulus and response onsets.
236
+
237
+ This function inspects existing "contrast_trial_start" annotations and
238
+ adds new, zero-duration "anchor" annotations at the precise onsets of
239
+ stimuli and responses for each trial.
240
+
241
+ Parameters
242
+ ----------
243
+ raw : mne.io.Raw
244
+ The raw data object with "contrast_trial_start" annotations.
245
+ stim_desc : str, default "stimulus_anchor"
246
+ The description for the new stimulus annotations.
247
+ resp_desc : str, default "response_anchor"
248
+ The description for the new response annotations.
249
+
250
+ Returns
251
+ -------
252
+ mne.io.Raw
253
+ The `raw` object with the auxiliary annotations added.
254
+
255
+ """
180
256
  ann = raw.annotations
181
257
  mask = ann.description == "contrast_trial_start"
182
258
  if not np.any(mask):
@@ -189,28 +265,24 @@ def add_aux_anchors(raw, stim_desc="stimulus_anchor", resp_desc="response_anchor
189
265
  ex = ann.extras[idx] if ann.extras is not None else {}
190
266
  t0 = float(ann.onset[idx])
191
267
 
192
- stim_t = ex["stimulus_onset"]
193
- resp_t = ex["response_onset"]
268
+ stim_t = ex.get("stimulus_onset")
269
+ resp_t = ex.get("response_onset")
194
270
 
195
271
  if stim_t is None or (isinstance(stim_t, float) and np.isnan(stim_t)):
196
- rtt = ex["rt_from_trialstart"]
197
- rts = ex["rt_from_stimulus"]
272
+ rtt = ex.get("rt_from_trialstart")
273
+ rts = ex.get("rt_from_stimulus")
198
274
  if rtt is not None and rts is not None:
199
275
  stim_t = t0 + float(rtt) - float(rts)
200
276
 
201
277
  if resp_t is None or (isinstance(resp_t, float) and np.isnan(resp_t)):
202
- rtt = ex["rt_from_trialstart"]
278
+ rtt = ex.get("rt_from_trialstart")
203
279
  if rtt is not None:
204
280
  resp_t = t0 + float(rtt)
205
281
 
206
- if (stim_t is not None) and not (
207
- isinstance(stim_t, float) and np.isnan(stim_t)
208
- ):
282
+ if stim_t is not None and not (isinstance(stim_t, float) and np.isnan(stim_t)):
209
283
  stim_onsets.append(float(stim_t))
210
284
  stim_extras.append(dict(ex, anchor="stimulus"))
211
- if (resp_t is not None) and not (
212
- isinstance(resp_t, float) and np.isnan(resp_t)
213
- ):
285
+ if resp_t is not None and not (isinstance(resp_t, float) and np.isnan(resp_t)):
214
286
  resp_onsets.append(float(resp_t))
215
287
  resp_extras.append(dict(ex, anchor="response"))
216
288
 
@@ -220,7 +292,7 @@ def add_aux_anchors(raw, stim_desc="stimulus_anchor", resp_desc="response_anchor
220
292
  onset=new_onsets,
221
293
  duration=np.zeros_like(new_onsets, dtype=float),
222
294
  description=[stim_desc] * len(stim_onsets) + [resp_desc] * len(resp_onsets),
223
- orig_time=raw.info["meas_date"],
295
+ orig_time=raw.info.get("meas_date"),
224
296
  extras=stim_extras + resp_extras,
225
297
  )
226
298
  raw.set_annotations(ann + aux, verbose=False)
@@ -228,10 +300,10 @@ def add_aux_anchors(raw, stim_desc="stimulus_anchor", resp_desc="response_anchor
228
300
 
229
301
 
230
302
  def add_extras_columns(
231
- windows_concat_ds,
232
- original_concat_ds,
233
- desc="contrast_trial_start",
234
- keys=(
303
+ windows_concat_ds: BaseConcatDataset,
304
+ original_concat_ds: BaseConcatDataset,
305
+ desc: str = "contrast_trial_start",
306
+ keys: tuple = (
235
307
  "target",
236
308
  "rt_from_stimulus",
237
309
  "rt_from_trialstart",
@@ -240,7 +312,31 @@ def add_extras_columns(
240
312
  "correct",
241
313
  "response_type",
242
314
  ),
243
- ):
315
+ ) -> BaseConcatDataset:
316
+ """Add columns from annotation extras to a windowed dataset's metadata.
317
+
318
+ This function propagates trial-level information stored in the `extras`
319
+ of annotations to the `metadata` DataFrame of a `WindowsDataset`.
320
+
321
+ Parameters
322
+ ----------
323
+ windows_concat_ds : BaseConcatDataset
324
+ The windowed dataset whose metadata will be updated.
325
+ original_concat_ds : BaseConcatDataset
326
+ The original (non-windowed) dataset containing the raw data and
327
+ annotations with the `extras` to be added.
328
+ desc : str, default "contrast_trial_start"
329
+ The description of the annotations to source the extras from.
330
+ keys : tuple, default (...)
331
+ The keys to extract from each annotation's `extras` dictionary and
332
+ add as columns to the metadata.
333
+
334
+ Returns
335
+ -------
336
+ BaseConcatDataset
337
+ The `windows_concat_ds` with updated metadata.
338
+
339
+ """
244
340
  float_cols = {
245
341
  "target",
246
342
  "rt_from_stimulus",
@@ -292,7 +388,6 @@ def add_extras_columns(
292
388
  else: # response_type
293
389
  ser = pd.Series(vals, index=md.index, dtype="string")
294
390
 
295
- # Replace the whole column to avoid dtype conflicts
296
391
  md[k] = ser
297
392
 
298
393
  win_ds.metadata = md.reset_index(drop=True)
@@ -303,7 +398,25 @@ def add_extras_columns(
303
398
  return windows_concat_ds
304
399
 
305
400
 
306
- def keep_only_recordings_with(desc, concat_ds):
401
+ def keep_only_recordings_with(
402
+ desc: str, concat_ds: BaseConcatDataset
403
+ ) -> BaseConcatDataset:
404
+ """Filter a concatenated dataset to keep only recordings with a specific annotation.
405
+
406
+ Parameters
407
+ ----------
408
+ desc : str
409
+ The description of the annotation that must be present in a recording
410
+ for it to be kept.
411
+ concat_ds : BaseConcatDataset
412
+ The concatenated dataset to filter.
413
+
414
+ Returns
415
+ -------
416
+ BaseConcatDataset
417
+ A new concatenated dataset containing only the filtered recordings.
418
+
419
+ """
307
420
  kept = []
308
421
  for ds in concat_ds.datasets:
309
422
  if np.any(ds.raw.annotations.description == desc):
eegdash/logging.py CHANGED
@@ -29,6 +29,25 @@ root_logger.setLevel(logging.INFO)
29
29
  # Now, get your package-specific logger. It will inherit the
30
30
  # configuration from the root logger we just set up.
31
31
  logger = logging.getLogger("eegdash")
32
+ """The primary logger for the EEGDash package.
33
+
34
+ This logger is configured to use :class:`rich.logging.RichHandler` for
35
+ formatted, colorful output in the console. It inherits its base configuration
36
+ from the root logger, which is set to the ``INFO`` level.
37
+
38
+ Examples
39
+ --------
40
+ Usage in other modules:
41
+
42
+ .. code-block:: python
43
+
44
+ from .logging import logger
45
+
46
+ logger.info("This is an informational message.")
47
+ logger.warning("This is a warning.")
48
+ logger.error("This is an error.")
49
+ """
50
+
32
51
 
33
52
  logger.setLevel(logging.INFO)
34
53