eegdash 0.3.3.dev61__py3-none-any.whl → 0.5.0.dev180784713__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- eegdash/__init__.py +19 -6
- eegdash/api.py +336 -539
- eegdash/bids_eeg_metadata.py +495 -0
- eegdash/const.py +349 -0
- eegdash/dataset/__init__.py +28 -0
- eegdash/dataset/base.py +311 -0
- eegdash/dataset/bids_dataset.py +641 -0
- eegdash/dataset/dataset.py +692 -0
- eegdash/dataset/dataset_summary.csv +255 -0
- eegdash/dataset/registry.py +287 -0
- eegdash/downloader.py +197 -0
- eegdash/features/__init__.py +15 -13
- eegdash/features/datasets.py +329 -138
- eegdash/features/decorators.py +105 -13
- eegdash/features/extractors.py +233 -63
- eegdash/features/feature_bank/__init__.py +12 -12
- eegdash/features/feature_bank/complexity.py +22 -20
- eegdash/features/feature_bank/connectivity.py +27 -28
- eegdash/features/feature_bank/csp.py +3 -1
- eegdash/features/feature_bank/dimensionality.py +6 -6
- eegdash/features/feature_bank/signal.py +29 -30
- eegdash/features/feature_bank/spectral.py +40 -44
- eegdash/features/feature_bank/utils.py +8 -0
- eegdash/features/inspect.py +126 -15
- eegdash/features/serialization.py +58 -17
- eegdash/features/utils.py +90 -16
- eegdash/hbn/__init__.py +28 -0
- eegdash/hbn/preprocessing.py +105 -0
- eegdash/hbn/windows.py +428 -0
- eegdash/logging.py +54 -0
- eegdash/mongodb.py +55 -24
- eegdash/paths.py +52 -0
- eegdash/utils.py +29 -1
- eegdash-0.5.0.dev180784713.dist-info/METADATA +121 -0
- eegdash-0.5.0.dev180784713.dist-info/RECORD +38 -0
- eegdash-0.5.0.dev180784713.dist-info/licenses/LICENSE +29 -0
- eegdash/data_config.py +0 -34
- eegdash/data_utils.py +0 -687
- eegdash/dataset.py +0 -69
- eegdash/preprocessing.py +0 -63
- eegdash-0.3.3.dev61.dist-info/METADATA +0 -192
- eegdash-0.3.3.dev61.dist-info/RECORD +0 -28
- eegdash-0.3.3.dev61.dist-info/licenses/LICENSE +0 -23
- {eegdash-0.3.3.dev61.dist-info → eegdash-0.5.0.dev180784713.dist-info}/WHEEL +0 -0
- {eegdash-0.3.3.dev61.dist-info → eegdash-0.5.0.dev180784713.dist-info}/top_level.txt +0 -0
eegdash/hbn/windows.py
ADDED
|
@@ -0,0 +1,428 @@
|
|
|
1
|
+
# Authors: The EEGDash contributors.
|
|
2
|
+
# License: BSD-3-Clause
|
|
3
|
+
# Copyright the EEGDash contributors.
|
|
4
|
+
|
|
5
|
+
"""Windowing and trial processing utilities for HBN datasets.
|
|
6
|
+
|
|
7
|
+
This module provides functions for building trial tables, adding auxiliary anchors,
|
|
8
|
+
annotating trials with targets, and filtering recordings based on various criteria.
|
|
9
|
+
These utilities are specifically designed for working with HBN EEG data structures
|
|
10
|
+
and experimental paradigms.
|
|
11
|
+
"""
|
|
12
|
+
|
|
13
|
+
import logging
|
|
14
|
+
|
|
15
|
+
import mne
|
|
16
|
+
import numpy as np
|
|
17
|
+
import pandas as pd
|
|
18
|
+
from mne_bids import get_bids_path_from_fname
|
|
19
|
+
|
|
20
|
+
from braindecode.datasets.base import BaseConcatDataset
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def build_trial_table(events_df: pd.DataFrame) -> pd.DataFrame:
    """Build a table of contrast trials from an events DataFrame.

    Events are sorted by onset (stable sort). Each ``"contrastTrial_start"``
    event opens a trial that lasts until the next ``"contrastTrial_start"``
    event; the final trial start has no closing event and is therefore
    dropped. Within a trial, the stimulus is the first ``"left_target"`` /
    ``"right_target"`` event, and the response is the first
    ``"left_buttonPress"`` / ``"right_buttonPress"`` event at or after the
    stimulus (or at or after the trial start when there is no stimulus).

    Parameters
    ----------
    events_df : pandas.DataFrame
        A DataFrame containing event information, with at least "onset" and
        "value" columns. An optional "feedback" column on the response rows
        (``"smiley_face"`` / ``"sad_face"``) is used to derive correctness.

    Returns
    -------
    pandas.DataFrame
        One row per contrast trial with the columns ``trial_start_onset``,
        ``trial_stop_onset``, ``stimulus_onset``, ``response_onset``,
        ``rt_from_stimulus``, ``rt_from_trialstart``, ``response_type`` and
        ``correct``. The columns are present even when no trial is found.

    Raises
    ------
    ValueError
        If the "onset" column cannot be converted to numeric values.

    """
    # Fixed schema so that an empty result still exposes every column that
    # downstream code indexes by name.
    columns = [
        "trial_start_onset",
        "trial_stop_onset",
        "stimulus_onset",
        "response_onset",
        "rt_from_stimulus",
        "rt_from_trialstart",
        "response_type",
        "correct",
    ]

    events_df = events_df.copy()
    events_df["onset"] = pd.to_numeric(events_df["onset"], errors="raise")
    events_df = events_df.sort_values("onset", kind="mergesort").reset_index(drop=True)

    trials = events_df[events_df["value"].eq("contrastTrial_start")].copy()
    stimuli = events_df[events_df["value"].isin(["left_target", "right_target"])].copy()
    responses = events_df[
        events_df["value"].isin(["left_buttonPress", "right_buttonPress"])
    ].copy()

    # A trial ends where the next one starts; the last start has no end.
    trials = trials.reset_index(drop=True)
    trials["next_onset"] = trials["onset"].shift(-1)
    trials = trials.dropna(subset=["next_onset"]).reset_index(drop=True)

    rows = []
    for _, tr in trials.iterrows():
        start = float(tr["onset"])
        end = float(tr["next_onset"])

        stim_block = stimuli[(stimuli["onset"] >= start) & (stimuli["onset"] < end)]
        stim_onset = np.nan if stim_block.empty else float(stim_block.iloc[0]["onset"])

        # Responses only count from the stimulus onwards when there is one.
        resp_from = start if np.isnan(stim_onset) else stim_onset
        resp_block = responses[
            (responses["onset"] >= resp_from) & (responses["onset"] < end)
        ]

        if resp_block.empty:
            resp_onset = np.nan
            resp_type = None
            feedback = None
        else:
            first_resp = resp_block.iloc[0]
            resp_onset = float(first_resp["onset"])
            resp_type = first_resp["value"]
            # "feedback" is optional: .get() yields None instead of raising
            # KeyError when the events table has no such column.
            feedback = first_resp.get("feedback")

        rt_from_stim = (
            (resp_onset - stim_onset)
            if (not np.isnan(stim_onset) and not np.isnan(resp_onset))
            else np.nan
        )
        rt_from_trial = (resp_onset - start) if not np.isnan(resp_onset) else np.nan

        # Anything other than the two known feedback strings stays None.
        correct = None
        if isinstance(feedback, str):
            if feedback == "smiley_face":
                correct = True
            elif feedback == "sad_face":
                correct = False

        rows.append(
            {
                "trial_start_onset": start,
                "trial_stop_onset": end,
                "stimulus_onset": stim_onset,
                "response_onset": resp_onset,
                "rt_from_stimulus": rt_from_stim,
                "rt_from_trialstart": rt_from_trial,
                "response_type": resp_type,
                "correct": correct,
            }
        )

    return pd.DataFrame(rows, columns=columns)
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
def _to_float_or_none(x):
    """Return ``float(x)``, or ``None`` when pandas treats ``x`` as missing."""
    if pd.isna(x):
        return None
    return float(x)
|
|
116
|
+
|
|
117
|
+
|
|
118
|
+
def _to_int_or_none(x):
    """Safely convert a value to int or None.

    Missing values (as judged by ``pandas.isna``) and values that ``int()``
    cannot convert both yield ``None``. Booleans map to 0/1 and floats are
    truncated, as ``int()`` does.
    """
    if pd.isna(x):
        return None
    try:
        # int() already handles bool, numpy bool/integer and numeric strings.
        return int(x)
    except (ValueError, TypeError, OverflowError):
        # OverflowError: int(float("inf")) - infinities are not "missing"
        # for pandas, so they reach int() and must not crash the caller.
        return None
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
def _to_str_or_none(x):
    """Return ``str(x)``, or ``None`` when ``x`` is ``None`` or a float NaN."""
    if x is None:
        return None
    if isinstance(x, float) and np.isnan(x):
        return None
    return str(x)
|
|
135
|
+
|
|
136
|
+
|
|
137
|
+
def annotate_trials_with_target(
    raw: mne.io.Raw,
    target_field: str = "rt_from_stimulus",
    epoch_length: float = 2.0,
    require_stimulus: bool = True,
    require_response: bool = True,
) -> mne.io.Raw:
    """Annotate each contrast trial of a recording with a target value.

    The BIDS ``events.tsv`` file that belongs to ``raw`` is located from the
    recording's file name, turned into a trial table with
    :func:`build_trial_table`, and one annotation described as
    "contrast_trial_start" is created per remaining trial. Every annotation
    carries an ``extras`` dictionary with the trial metrics plus a "target"
    entry taken from ``target_field``. The annotations are assigned to
    ``raw`` through ``raw.set_annotations``.

    Parameters
    ----------
    raw : mne.io.Raw
        The raw data object. It must report exactly one file name, from
        which the BIDS path is derived.
    target_field : str, default "rt_from_stimulus"
        Trial-table column whose values become the "target" extra.
    epoch_length : float, default 2.0
        Duration given to every created annotation.
    require_stimulus : bool, default True
        If True, drop trials without a stimulus onset.
    require_response : bool, default True
        If True, drop trials without a response onset.

    Returns
    -------
    mne.io.Raw
        The same ``raw`` object, with the new annotations set.

    Raises
    ------
    KeyError
        If ``target_field`` is not a column of the computed trial table.

    """
    fnames = raw.filenames
    assert len(fnames) == 1, "Expected a single filename"
    events_path = (
        get_bids_path_from_fname(fnames[0])
        .update(suffix="events", extension=".tsv")
        .fpath
    )

    # Load the events and order them by numeric onset (stable sort).
    loaded = pd.read_csv(events_path, sep="\t")
    loaded["onset"] = pd.to_numeric(loaded["onset"], errors="raise")
    events_df = loaded.sort_values("onset", kind="mergesort").reset_index(drop=True)

    trial_table = build_trial_table(events_df)

    if require_stimulus:
        trial_table = trial_table[trial_table["stimulus_onset"].notna()].copy()
    if require_response:
        trial_table = trial_table[trial_table["response_onset"].notna()].copy()

    if target_field not in trial_table.columns:
        raise KeyError(f"{target_field} not in computed trial table.")
    targets = trial_table[target_field].astype(float)

    # One extras dict per trial, with values normalised to plain Python
    # types (or None for missing entries).
    extras = [
        {
            "target": _to_float_or_none(target),
            "rt_from_stimulus": _to_float_or_none(row["rt_from_stimulus"]),
            "rt_from_trialstart": _to_float_or_none(row["rt_from_trialstart"]),
            "stimulus_onset": _to_float_or_none(row["stimulus_onset"]),
            "response_onset": _to_float_or_none(row["response_onset"]),
            "correct": _to_int_or_none(row["correct"]),
            "response_type": _to_str_or_none(row["response_type"]),
        }
        for target, (_, row) in zip(targets, trial_table.iterrows())
    ]

    n_trials = len(trial_table)
    trial_annotations = mne.Annotations(
        onset=trial_table["trial_start_onset"].to_numpy(float),
        duration=np.full(n_trials, float(epoch_length), dtype=float),
        description=["contrast_trial_start"] * n_trials,
        orig_time=raw.info.get("meas_date"),
        extras=extras,
    )
    raw.set_annotations(trial_annotations, verbose=False)
    return raw
|
|
228
|
+
|
|
229
|
+
|
|
230
|
+
def add_aux_anchors(
    raw: mne.io.Raw,
    stim_desc: str = "stimulus_anchor",
    resp_desc: str = "response_anchor",
) -> mne.io.Raw:
    """Add zero-duration anchor annotations at stimulus and response onsets.

    For every existing "contrast_trial_start" annotation, the stimulus and
    response onsets are read from the annotation's ``extras``. When an onset
    is absent it is reconstructed from the reaction-time extras relative to
    the annotation onset. Each onset that could be determined becomes a new
    annotation whose extras are a copy of the trial's extras plus an
    "anchor" entry ("stimulus" or "response").

    Parameters
    ----------
    raw : mne.io.Raw
        The raw data object with "contrast_trial_start" annotations.
    stim_desc : str, default "stimulus_anchor"
        The description for the new stimulus annotations.
    resp_desc : str, default "response_anchor"
        The description for the new response annotations.

    Returns
    -------
    mne.io.Raw
        The same ``raw`` object. It is returned unchanged when it has no
        "contrast_trial_start" annotation or no onset could be determined.

    """

    def _missing(value):
        # An onset counts as missing when it is None or a float NaN.
        return value is None or (isinstance(value, float) and np.isnan(value))

    ann = raw.annotations
    trial_indices = np.where(ann.description == "contrast_trial_start")[0]
    if trial_indices.size == 0:
        return raw

    stim_onsets, stim_extras = [], []
    resp_onsets, resp_extras = [], []

    for idx in trial_indices:
        ex = ann.extras[idx] if ann.extras is not None else {}
        t0 = float(ann.onset[idx])

        stim_t = ex.get("stimulus_onset")
        resp_t = ex.get("response_onset")
        rtt = ex.get("rt_from_trialstart")

        if _missing(stim_t):
            # stimulus = trial start + (response - trial start) - (response - stimulus)
            rts = ex.get("rt_from_stimulus")
            if rtt is not None and rts is not None:
                stim_t = t0 + float(rtt) - float(rts)

        if _missing(resp_t) and rtt is not None:
            resp_t = t0 + float(rtt)

        if not _missing(stim_t):
            stim_onsets.append(float(stim_t))
            stim_extras.append(dict(ex, anchor="stimulus"))
        if not _missing(resp_t):
            resp_onsets.append(float(resp_t))
            resp_extras.append(dict(ex, anchor="response"))

    # Stimulus anchors first, then response anchors; descriptions and extras
    # below follow the same order.
    new_onsets = np.array(stim_onsets + resp_onsets, dtype=float)
    if len(new_onsets):
        aux = mne.Annotations(
            onset=new_onsets,
            duration=np.zeros_like(new_onsets, dtype=float),
            description=[stim_desc] * len(stim_onsets) + [resp_desc] * len(resp_onsets),
            orig_time=raw.info.get("meas_date"),
            extras=stim_extras + resp_extras,
        )
        raw.set_annotations(ann + aux, verbose=False)
    return raw
|
|
300
|
+
|
|
301
|
+
|
|
302
|
+
def add_extras_columns(
    windows_concat_ds: BaseConcatDataset,
    original_concat_ds: BaseConcatDataset,
    desc: str = "contrast_trial_start",
    keys: tuple = (
        "target",
        "rt_from_stimulus",
        "rt_from_trialstart",
        "stimulus_onset",
        "response_onset",
        "correct",
        "response_type",
    ),
) -> BaseConcatDataset:
    """Copy annotation extras into the metadata of a windowed dataset.

    The two datasets are walked pairwise. For each pair, the extras of the
    annotations matching ``desc`` in the original recording are spread over
    the windows of the corresponding windowed dataset: a window whose
    ``i_window_in_trial`` is 0 starts a new trial, and every window receives
    the values of the trial it belongs to. Pairs whose original recording has
    no matching annotation are left untouched.

    Parameters
    ----------
    windows_concat_ds : BaseConcatDataset
        The windowed dataset whose metadata will be updated in place.
    original_concat_ds : BaseConcatDataset
        The original (non-windowed) dataset holding the raw data and the
        annotations whose ``extras`` are propagated.
    desc : str, default "contrast_trial_start"
        The description of the annotations to source the extras from.
    keys : tuple, default (...)
        The keys read from each annotation's ``extras`` dictionary and added
        as metadata columns. "correct" becomes a nullable ``Int64`` column,
        the reaction-time/onset/target keys become ``Float64`` columns, and
        any other key becomes a ``string`` column.

    Returns
    -------
    BaseConcatDataset
        The same ``windows_concat_ds`` object, with updated metadata. For
        windowed datasets that have a ``y`` attribute, ``y`` is replaced by
        the metadata "target" column as an ``(N, 1)`` float array.

    """
    float_cols = {
        "target",
        "rt_from_stimulus",
        "rt_from_trialstart",
        "stimulus_onset",
        "response_onset",
    }

    def _as_column(key, values, index):
        # Pick the nullable dtype that fits the kind of value stored under key.
        if key == "correct":
            data = [None if v is None else int(bool(v)) for v in values]
            return pd.Series(data, index=index, dtype="Int64")
        if key in float_cols:
            data = [np.nan if v is None else float(v) for v in values]
            return pd.Series(data, index=index, dtype="Float64")
        # response_type
        return pd.Series(values, index=index, dtype="string")

    for win_ds, base_ds in zip(windows_concat_ds.datasets, original_concat_ds.datasets):
        ann = base_ds.raw.annotations
        idx = np.where(ann.description == desc)[0]
        if idx.size == 0:
            continue

        # One dict per matching annotation; keys absent from the extras map to None.
        per_trial = []
        for i in idx:
            trial_values = {}
            for k in keys:
                if ann.extras is not None and k in ann.extras[i]:
                    trial_values[k] = ann.extras[i][k]
                else:
                    trial_values[k] = None
            per_trial.append(trial_values)

        md = win_ds.metadata.copy()
        # Running count of "first window" flags gives a 0-based trial id per window.
        starts_trial = md["i_window_in_trial"].to_numpy() == 0
        trial_ids = starts_trial.cumsum() - 1
        n_trials = trial_ids.max() + 1 if len(trial_ids) else 0
        assert n_trials == len(per_trial), (
            f"Trial mismatch: {n_trials} vs {len(per_trial)}"
        )

        for k in keys:
            vals = [per_trial[t][k] if t < len(per_trial) else None for t in trial_ids]
            md[k] = _as_column(k, vals, md.index)

        win_ds.metadata = md.reset_index(drop=True)
        if hasattr(win_ds, "y"):
            y_np = win_ds.metadata["target"].astype(float).to_numpy()
            win_ds.y = y_np[:, None]  # (N, 1)

    return windows_concat_ds
|
|
399
|
+
|
|
400
|
+
|
|
401
|
+
def keep_only_recordings_with(
    desc: str, concat_ds: BaseConcatDataset
) -> BaseConcatDataset:
    """Keep only the recordings that contain a given annotation.

    Parameters
    ----------
    desc : str
        Annotation description that a recording must contain to be kept.
    concat_ds : BaseConcatDataset
        The concatenated dataset to filter.

    Returns
    -------
    BaseConcatDataset
        A new concatenated dataset built from the recordings that have at
        least one annotation described as ``desc``. A warning naming the
        recording's first file is logged for each recording left out.

    """
    kept = []
    for ds in concat_ds.datasets:
        has_event = np.any(ds.raw.annotations.description == desc)
        if not has_event:
            logging.warning(
                f"Recording {ds.raw.filenames[0]} does not contain event '{desc}'"
            )
            continue
        kept.append(ds)
    return BaseConcatDataset(kept)
|
eegdash/logging.py
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
1
|
+
# Authors: The EEGDash contributors.
|
|
2
|
+
# License: BSD-3-Clause
|
|
3
|
+
# Copyright the EEGDash contributors.
|
|
4
|
+
|
|
5
|
+
"""Logging configuration for EEGDash.
|
|
6
|
+
|
|
7
|
+
This module sets up centralized logging for the EEGDash package using Rich for enhanced
|
|
8
|
+
console output formatting. It provides a consistent logging interface across all modules.
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
import logging
|
|
12
|
+
|
|
13
|
+
from rich.logging import RichHandler
|
|
14
|
+
|
|
15
|
+
# Configure the process-wide root logger at import time.
root_logger = logging.getLogger()

# 1. Discard every handler currently attached to the root logger so that the
#    RichHandler added below is the only one attached to it.
root_logger.handlers = []

# 2. Attach a RichHandler (rich tracebacks and markup enabled).
root_logger.addHandler(RichHandler(rich_tracebacks=True, markup=True))

# 3. Set the level for the root logger.
root_logger.setLevel(logging.INFO)

# Package-specific logger. No handler is attached to it here; the only
# handler this module installs is the RichHandler on the root logger above.
logger = logging.getLogger("eegdash")
"""The primary logger for the EEGDash package.

This logger is configured to use :class:`rich.logging.RichHandler` for
formatted, colorful output in the console. It inherits its base configuration
from the root logger, which is set to the ``INFO`` level.

Examples
--------
Usage in other modules:

.. code-block:: python

from .logging import logger

logger.info("This is an informational message.")
logger.warning("This is a warning.")
logger.error("This is an error.")
"""


# The package logger's own level is also set to INFO explicitly.
logger.setLevel(logging.INFO)

__all__ = ["logger"]
|
eegdash/mongodb.py
CHANGED
|
@@ -1,42 +1,66 @@
|
|
|
1
|
+
# Authors: The EEGDash contributors.
|
|
2
|
+
# License: BSD-3-Clause
|
|
3
|
+
# Copyright the EEGDash contributors.
|
|
4
|
+
|
|
5
|
+
"""MongoDB connection and operations management.
|
|
6
|
+
|
|
7
|
+
This module provides a thread-safe singleton manager for MongoDB connections,
|
|
8
|
+
ensuring that connections to the database are handled efficiently and consistently
|
|
9
|
+
across the application.
|
|
10
|
+
"""
|
|
11
|
+
|
|
1
12
|
import threading
|
|
2
13
|
|
|
3
14
|
from pymongo import MongoClient
|
|
4
|
-
|
|
5
|
-
|
|
6
|
-
# These methods provide a high-level interface to interact with the MongoDB
|
|
7
|
-
# collection, allowing users to find, add, and update EEG data records.
|
|
8
|
-
# - find:
|
|
9
|
-
# - exist:
|
|
10
|
-
# - add_request:
|
|
11
|
-
# - add:
|
|
12
|
-
# - update_request:
|
|
13
|
-
# - remove_field:
|
|
14
|
-
# - remove_field_from_db:
|
|
15
|
-
# - close: Close the MongoDB connection.
|
|
16
|
-
# - __del__: Destructor to close the MongoDB connection.
|
|
15
|
+
from pymongo.collection import Collection
|
|
16
|
+
from pymongo.database import Database
|
|
17
17
|
|
|
18
18
|
|
|
19
19
|
class MongoConnectionManager:
|
|
20
|
-
"""
|
|
20
|
+
"""A thread-safe singleton to manage MongoDB client connections.
|
|
21
|
+
|
|
22
|
+
This class ensures that only one connection instance is created for each
|
|
23
|
+
unique combination of a connection string and staging flag. It provides
|
|
24
|
+
class methods to get a client and to close all active connections.
|
|
25
|
+
|
|
26
|
+
Attributes
|
|
27
|
+
----------
|
|
28
|
+
_instances : dict
|
|
29
|
+
A dictionary to store singleton instances, mapping a
|
|
30
|
+
(connection_string, is_staging) tuple to a (client, db, collection)
|
|
31
|
+
tuple.
|
|
32
|
+
_lock : threading.Lock
|
|
33
|
+
A lock to ensure thread-safe instantiation of clients.
|
|
21
34
|
|
|
22
|
-
|
|
35
|
+
"""
|
|
36
|
+
|
|
37
|
+
_instances: dict[tuple[str, bool], tuple[MongoClient, Database, Collection]] = {}
|
|
23
38
|
_lock = threading.Lock()
|
|
24
39
|
|
|
25
40
|
@classmethod
|
|
26
|
-
def get_client(
|
|
27
|
-
|
|
41
|
+
def get_client(
|
|
42
|
+
cls, connection_string: str, is_staging: bool = False
|
|
43
|
+
) -> tuple[MongoClient, Database, Collection]:
|
|
44
|
+
"""Get or create a MongoDB client for the given connection parameters.
|
|
45
|
+
|
|
46
|
+
This method returns a cached client if one already exists for the given
|
|
47
|
+
connection string and staging flag. Otherwise, it creates a new client,
|
|
48
|
+
connects to the appropriate database ("eegdash" or "eegdashstaging"),
|
|
49
|
+
and returns the client, database, and "records" collection.
|
|
28
50
|
|
|
29
51
|
Parameters
|
|
30
52
|
----------
|
|
31
53
|
connection_string : str
|
|
32
|
-
The MongoDB connection string
|
|
33
|
-
is_staging : bool
|
|
34
|
-
|
|
54
|
+
The MongoDB connection string.
|
|
55
|
+
is_staging : bool, default False
|
|
56
|
+
If True, connect to the staging database ("eegdashstaging").
|
|
57
|
+
Otherwise, connect to the production database ("eegdash").
|
|
35
58
|
|
|
36
59
|
Returns
|
|
37
60
|
-------
|
|
38
|
-
tuple
|
|
39
|
-
A tuple
|
|
61
|
+
tuple[MongoClient, Database, Collection]
|
|
62
|
+
A tuple containing the connected MongoClient instance, the Database
|
|
63
|
+
object, and the Collection object for the "records" collection.
|
|
40
64
|
|
|
41
65
|
"""
|
|
42
66
|
# Create a unique key based on connection string and staging flag
|
|
@@ -55,8 +79,12 @@ class MongoConnectionManager:
|
|
|
55
79
|
return cls._instances[key]
|
|
56
80
|
|
|
57
81
|
@classmethod
|
|
58
|
-
def close_all(cls):
|
|
59
|
-
"""Close all MongoDB client connections.
|
|
82
|
+
def close_all(cls) -> None:
|
|
83
|
+
"""Close all managed MongoDB client connections.
|
|
84
|
+
|
|
85
|
+
This method iterates through all cached client instances and closes
|
|
86
|
+
their connections. It also clears the instance cache.
|
|
87
|
+
"""
|
|
60
88
|
with cls._lock:
|
|
61
89
|
for client, _, _ in cls._instances.values():
|
|
62
90
|
try:
|
|
@@ -64,3 +92,6 @@ class MongoConnectionManager:
|
|
|
64
92
|
except Exception:
|
|
65
93
|
pass
|
|
66
94
|
cls._instances.clear()
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
__all__ = ["MongoConnectionManager"]
|
eegdash/paths.py
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
1
|
+
# Authors: The EEGDash contributors.
|
|
2
|
+
# License: BSD-3-Clause
|
|
3
|
+
# Copyright the EEGDash contributors.
|
|
4
|
+
|
|
5
|
+
"""Path utilities and cache directory management.
|
|
6
|
+
|
|
7
|
+
This module provides functions for resolving consistent cache directories and path
|
|
8
|
+
management throughout the EEGDash package, with integration to MNE-Python's
|
|
9
|
+
configuration system.
|
|
10
|
+
"""
|
|
11
|
+
|
|
12
|
+
from __future__ import annotations
|
|
13
|
+
|
|
14
|
+
import os
|
|
15
|
+
from pathlib import Path
|
|
16
|
+
|
|
17
|
+
from mne.utils import get_config as mne_get_config
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def get_default_cache_dir() -> Path:
    """Resolve the default cache directory for EEGDash data.

    Candidates are tried in this order and the first non-empty one wins:

    1. The ``EEGDASH_CACHE_DIR`` environment variable.
    2. The ``MNE_DATA`` value from the MNE-Python configuration.
    3. A hidden ``.eegdash_cache`` directory under the current working
       directory.

    Returns
    -------
    pathlib.Path
        The cache directory. A value taken from (1) or (2) has ``~`` expanded
        and is resolved; the fallback (3) is built from ``Path.cwd()``.

    """
    # The MNE configuration is consulted only when the env var is unset or empty.
    configured = os.environ.get("EEGDASH_CACHE_DIR")
    if not configured:
        configured = mne_get_config("MNE_DATA")

    if configured:
        return Path(configured).expanduser().resolve()

    # Project-local hidden folder as the last resort.
    return Path.cwd() / ".eegdash_cache"
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
__all__ = ["get_default_cache_dir"]
|
eegdash/utils.py
CHANGED
|
@@ -1,7 +1,32 @@
|
|
|
1
|
+
# Authors: The EEGDash contributors.
|
|
2
|
+
# License: BSD-3-Clause
|
|
3
|
+
# Copyright the EEGDash contributors.
|
|
4
|
+
|
|
5
|
+
"""General utility functions for EEGDash.
|
|
6
|
+
|
|
7
|
+
This module contains miscellaneous utility functions used across the EEGDash package,
|
|
8
|
+
including MongoDB client initialization and configuration helpers.
|
|
9
|
+
"""
|
|
10
|
+
|
|
1
11
|
from mne.utils import get_config, set_config, use_log_level
|
|
2
12
|
|
|
3
13
|
|
|
4
|
-
def
|
|
14
|
+
def _init_mongo_client() -> None:
|
|
15
|
+
"""Initialize the default MongoDB connection URI in the MNE config.
|
|
16
|
+
|
|
17
|
+
This function checks if the ``EEGDASH_DB_URI`` is already set in the
|
|
18
|
+
MNE-Python configuration. If it is not set, this function sets it to the
|
|
19
|
+
default public EEGDash MongoDB Atlas cluster URI.
|
|
20
|
+
|
|
21
|
+
The operation is performed with MNE's logging level temporarily set to
|
|
22
|
+
"ERROR" to suppress verbose output.
|
|
23
|
+
|
|
24
|
+
Notes
|
|
25
|
+
-----
|
|
26
|
+
This is an internal helper function and is not intended for direct use
|
|
27
|
+
by end-users.
|
|
28
|
+
|
|
29
|
+
"""
|
|
5
30
|
with use_log_level("ERROR"):
|
|
6
31
|
if get_config("EEGDASH_DB_URI") is None:
|
|
7
32
|
set_config(
|
|
@@ -9,3 +34,6 @@ def __init__mongo_client():
|
|
|
9
34
|
"mongodb+srv://eegdash-user:mdzoMjQcHWTVnKDq@cluster0.vz35p.mongodb.net/?retryWrites=true&w=majority&appName=Cluster0",
|
|
10
35
|
set_env=True,
|
|
11
36
|
)
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
__all__ = ["_init_mongo_client"]
|