eegdash 0.3.3.dev61__py3-none-any.whl → 0.5.0.dev180784713__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45)
  1. eegdash/__init__.py +19 -6
  2. eegdash/api.py +336 -539
  3. eegdash/bids_eeg_metadata.py +495 -0
  4. eegdash/const.py +349 -0
  5. eegdash/dataset/__init__.py +28 -0
  6. eegdash/dataset/base.py +311 -0
  7. eegdash/dataset/bids_dataset.py +641 -0
  8. eegdash/dataset/dataset.py +692 -0
  9. eegdash/dataset/dataset_summary.csv +255 -0
  10. eegdash/dataset/registry.py +287 -0
  11. eegdash/downloader.py +197 -0
  12. eegdash/features/__init__.py +15 -13
  13. eegdash/features/datasets.py +329 -138
  14. eegdash/features/decorators.py +105 -13
  15. eegdash/features/extractors.py +233 -63
  16. eegdash/features/feature_bank/__init__.py +12 -12
  17. eegdash/features/feature_bank/complexity.py +22 -20
  18. eegdash/features/feature_bank/connectivity.py +27 -28
  19. eegdash/features/feature_bank/csp.py +3 -1
  20. eegdash/features/feature_bank/dimensionality.py +6 -6
  21. eegdash/features/feature_bank/signal.py +29 -30
  22. eegdash/features/feature_bank/spectral.py +40 -44
  23. eegdash/features/feature_bank/utils.py +8 -0
  24. eegdash/features/inspect.py +126 -15
  25. eegdash/features/serialization.py +58 -17
  26. eegdash/features/utils.py +90 -16
  27. eegdash/hbn/__init__.py +28 -0
  28. eegdash/hbn/preprocessing.py +105 -0
  29. eegdash/hbn/windows.py +428 -0
  30. eegdash/logging.py +54 -0
  31. eegdash/mongodb.py +55 -24
  32. eegdash/paths.py +52 -0
  33. eegdash/utils.py +29 -1
  34. eegdash-0.5.0.dev180784713.dist-info/METADATA +121 -0
  35. eegdash-0.5.0.dev180784713.dist-info/RECORD +38 -0
  36. eegdash-0.5.0.dev180784713.dist-info/licenses/LICENSE +29 -0
  37. eegdash/data_config.py +0 -34
  38. eegdash/data_utils.py +0 -687
  39. eegdash/dataset.py +0 -69
  40. eegdash/preprocessing.py +0 -63
  41. eegdash-0.3.3.dev61.dist-info/METADATA +0 -192
  42. eegdash-0.3.3.dev61.dist-info/RECORD +0 -28
  43. eegdash-0.3.3.dev61.dist-info/licenses/LICENSE +0 -23
  44. {eegdash-0.3.3.dev61.dist-info → eegdash-0.5.0.dev180784713.dist-info}/WHEEL +0 -0
  45. {eegdash-0.3.3.dev61.dist-info → eegdash-0.5.0.dev180784713.dist-info}/top_level.txt +0 -0
eegdash/hbn/windows.py ADDED
@@ -0,0 +1,428 @@
1
+ # Authors: The EEGDash contributors.
2
+ # License: BSD-3-Clause
3
+ # Copyright the EEGDash contributors.
4
+
5
+ """Windowing and trial processing utilities for HBN datasets.
6
+
7
+ This module provides functions for building trial tables, adding auxiliary anchors,
8
+ annotating trials with targets, and filtering recordings based on various criteria.
9
+ These utilities are specifically designed for working with HBN EEG data structures
10
+ and experimental paradigms.
11
+ """
12
+
13
+ import logging
14
+
15
+ import mne
16
+ import numpy as np
17
+ import pandas as pd
18
+ from mne_bids import get_bids_path_from_fname
19
+
20
+ from braindecode.datasets.base import BaseConcatDataset
21
+
22
+
23
def build_trial_table(events_df: pd.DataFrame) -> pd.DataFrame:
    """Build a table of contrast trials from an events DataFrame.

    This function processes a DataFrame of events (typically from a BIDS
    `events.tsv` file) to identify contrast trials and extract relevant
    metrics like stimulus onset, response onset, and reaction times.

    Trial segmentation:

    * Each "contrastTrial_start" event opens a trial spanning
      ``[onset, next "contrastTrial_start" onset)``. The last trial start has
      no closing event and is therefore dropped.
    * The stimulus is the first "left_target"/"right_target" event inside the
      trial.
    * The response is the first "left_buttonPress"/"right_buttonPress" event
      at or after the stimulus. When there is no stimulus, it is the first
      such event at or after the trial start.

    Parameters
    ----------
    events_df : pandas.DataFrame
        A DataFrame containing event information, with at least "onset" and
        "value" columns. It may also contain a "feedback" column. When
        present, the response row's feedback "smiley_face" maps to
        ``correct=True`` and "sad_face" maps to ``correct=False``. When the
        column is absent, ``correct`` is ``None``.

    Returns
    -------
    pandas.DataFrame
        A DataFrame where each row represents a single contrast trial. Its
        columns are "trial_start_onset", "trial_stop_onset",
        "stimulus_onset", "response_onset", "rt_from_stimulus",
        "rt_from_trialstart", "response_type" and "correct". The columns are
        present even when no trials are found. Missing onsets/RTs are NaN;
        missing response type / correctness are ``None``.

    Raises
    ------
    ValueError
        If the "onset" column cannot be converted to numbers.

    """
    columns = [
        "trial_start_onset",
        "trial_stop_onset",
        "stimulus_onset",
        "response_onset",
        "rt_from_stimulus",
        "rt_from_trialstart",
        "response_type",
        "correct",
    ]

    events_df = events_df.copy()
    events_df["onset"] = pd.to_numeric(events_df["onset"], errors="raise")
    # mergesort is stable: events sharing an onset keep their file order.
    events_df = events_df.sort_values("onset", kind="mergesort").reset_index(drop=True)

    trials = events_df[events_df["value"].eq("contrastTrial_start")].copy()
    stimuli = events_df[events_df["value"].isin(["left_target", "right_target"])].copy()
    responses = events_df[
        events_df["value"].isin(["left_buttonPress", "right_buttonPress"])
    ].copy()

    # Pair each trial start with the next one. The final start has no end
    # and is discarded.
    trials = trials.reset_index(drop=True)
    trials["next_onset"] = trials["onset"].shift(-1)
    trials = trials.dropna(subset=["next_onset"]).reset_index(drop=True)

    rows = []
    for trial_start, trial_end in zip(
        trials["onset"].to_numpy(dtype=float),
        trials["next_onset"].to_numpy(dtype=float),
    ):
        start = float(trial_start)
        end = float(trial_end)

        stim_block = stimuli[(stimuli["onset"] >= start) & (stimuli["onset"] < end)]
        stim_onset = np.nan if stim_block.empty else float(stim_block.iloc[0]["onset"])

        # Search for the response after the stimulus if there is one,
        # otherwise anywhere in the trial.
        resp_from = start if np.isnan(stim_onset) else stim_onset
        resp_block = responses[
            (responses["onset"] >= resp_from) & (responses["onset"] < end)
        ]

        if resp_block.empty:
            resp_onset = np.nan
            resp_type = None
            feedback = None
        else:
            first_resp = resp_block.iloc[0]
            resp_onset = float(first_resp["onset"])
            resp_type = first_resp["value"]
            # "feedback" is optional; Series.get yields None when the column
            # is absent, leaving correctness unknown.
            feedback = first_resp.get("feedback")

        rt_from_stim = (
            (resp_onset - stim_onset)
            if (not np.isnan(stim_onset) and not np.isnan(resp_onset))
            else np.nan
        )
        rt_from_trial = (resp_onset - start) if not np.isnan(resp_onset) else np.nan

        correct = None
        if isinstance(feedback, str):
            if feedback == "smiley_face":
                correct = True
            elif feedback == "sad_face":
                correct = False

        rows.append(
            {
                "trial_start_onset": start,
                "trial_stop_onset": end,
                "stimulus_onset": stim_onset,
                "response_onset": resp_onset,
                "rt_from_stimulus": rt_from_stim,
                "rt_from_trialstart": rt_from_trial,
                "response_type": resp_type,
                "correct": correct,
            }
        )

    # Explicit columns keep the schema stable even when `rows` is empty.
    return pd.DataFrame(rows, columns=columns)
111
+
112
+
113
+ def _to_float_or_none(x):
114
+ """Safely convert a value to float or None."""
115
+ return None if pd.isna(x) else float(x)
116
+
117
+
118
+ def _to_int_or_none(x):
119
+ """Safely convert a value to int or None."""
120
+ if pd.isna(x):
121
+ return None
122
+ if isinstance(x, (bool, np.bool_)):
123
+ return int(bool(x))
124
+ if isinstance(x, (int, np.integer)):
125
+ return int(x)
126
+ try:
127
+ return int(x)
128
+ except (ValueError, TypeError):
129
+ return None
130
+
131
+
132
+ def _to_str_or_none(x):
133
+ """Safely convert a value to string or None."""
134
+ return None if (x is None or (isinstance(x, float) and np.isnan(x))) else str(x)
135
+
136
+
137
def annotate_trials_with_target(
    raw: mne.io.Raw,
    target_field: str = "rt_from_stimulus",
    epoch_length: float = 2.0,
    require_stimulus: bool = True,
    require_response: bool = True,
) -> mne.io.Raw:
    """Create trial annotations with a specified target value.

    The function works in three steps:

    1. Locate the BIDS ``events.tsv`` file that belongs to the file backing
       ``raw``.
    2. Build a trial table from it with :func:`build_trial_table`.
    3. Replace all existing annotations of ``raw`` with one
       "contrast_trial_start" annotation per retained trial.

    Each annotation's ``extras`` dictionary holds the trial metrics plus a
    "target" key copied from ``target_field``.

    Parameters
    ----------
    raw : mne.io.Raw
        The raw data object. Must be backed by exactly one file from which
        the BIDS path can be derived.
    target_field : str, default "rt_from_stimulus"
        The column from the trial table to use as the "target" value in the
        annotation extras. Its values must be convertible to float.
    epoch_length : float, default 2.0
        The duration to set for each new annotation.
    require_stimulus : bool, default True
        If True, only include trials that have a recorded stimulus event.
    require_response : bool, default True
        If True, only include trials that have a recorded response event.

    Returns
    -------
    mne.io.Raw
        The same ``raw`` object, modified in place, with the new annotations
        set. Any annotations it had before are discarded.

    Raises
    ------
    ValueError
        If ``raw`` is not backed by exactly one file. Also raised if the
        ``target_field`` column cannot be converted to float.
    KeyError
        If `target_field` is not a valid column in the built trial table.

    Notes
    -----
    Onsets from ``events.tsv`` are used unchanged, with ``orig_time`` set to
    the measurement date. NOTE(review): confirm this alignment for
    recordings whose ``raw.first_samp`` is non-zero.

    """
    fnames = raw.filenames
    # Explicit check rather than `assert`, so it still runs under `python -O`.
    # In-memory objects report `(None,)` and have no events file to read.
    if len(fnames) != 1 or fnames[0] is None:
        raise ValueError(
            f"Expected a single filename backing the raw object, got {fnames!r}"
        )
    bids_path = get_bids_path_from_fname(fnames[0])
    events_file = bids_path.update(suffix="events", extension=".tsv").fpath

    events_df = (
        pd.read_csv(events_file, sep="\t")
        .assign(onset=lambda d: pd.to_numeric(d["onset"], errors="raise"))
        .sort_values("onset", kind="mergesort")
        .reset_index(drop=True)
    )

    trials = build_trial_table(events_df)

    if require_stimulus:
        trials = trials[trials["stimulus_onset"].notna()].copy()
    if require_response:
        trials = trials[trials["response_onset"].notna()].copy()

    if target_field not in trials.columns:
        raise KeyError(f"{target_field} not in computed trial table.")
    targets = trials[target_field].astype(float)

    onsets = trials["trial_start_onset"].to_numpy(float)
    durations = np.full(len(trials), float(epoch_length), dtype=float)
    descs = ["contrast_trial_start"] * len(trials)

    # Extras are converted to plain Python scalars / None. Each entry is
    # paired with its trial row by position.
    extras = [
        {
            "target": _to_float_or_none(target),
            "rt_from_stimulus": _to_float_or_none(row["rt_from_stimulus"]),
            "rt_from_trialstart": _to_float_or_none(row["rt_from_trialstart"]),
            "stimulus_onset": _to_float_or_none(row["stimulus_onset"]),
            "response_onset": _to_float_or_none(row["response_onset"]),
            "correct": _to_int_or_none(row["correct"]),
            "response_type": _to_str_or_none(row["response_type"]),
        }
        for target, (_, row) in zip(targets, trials.iterrows())
    ]

    new_ann = mne.Annotations(
        onset=onsets,
        duration=durations,
        description=descs,
        orig_time=raw.info.get("meas_date"),
        extras=extras,
    )
    raw.set_annotations(new_ann, verbose=False)
    return raw
228
+
229
+
230
def add_aux_anchors(
    raw: mne.io.Raw,
    stim_desc: str = "stimulus_anchor",
    resp_desc: str = "response_anchor",
) -> mne.io.Raw:
    """Append zero-duration stimulus/response anchor annotations to ``raw``.

    For every "contrast_trial_start" annotation, the stimulus and response
    onsets are read from its ``extras``.

    Stimulus onset: when "stimulus_onset" is missing (None/NaN), it is
    rebuilt as ``trial_onset + rt_from_trialstart - rt_from_stimulus``.

    Response onset: when "response_onset" is missing, it is rebuilt as
    ``trial_onset + rt_from_trialstart``.

    Every onset that can be determined yields one zero-duration annotation.
    Its extras are a copy of the trial extras plus an "anchor" key
    ("stimulus" or "response").

    The anchors are added on top of the existing annotations. Calling the
    function twice therefore adds the anchors twice.

    Parameters
    ----------
    raw : mne.io.Raw
        The raw data object with "contrast_trial_start" annotations.
    stim_desc : str, default "stimulus_anchor"
        Description given to the stimulus anchors.
    resp_desc : str, default "response_anchor"
        Description given to the response anchors.

    Returns
    -------
    mne.io.Raw
        ``raw`` itself. It is unchanged when there are no
        "contrast_trial_start" annotations or no onsets can be determined.

    """

    def _missing(value):
        # None and float NaN both mean "no onset available".
        return value is None or (isinstance(value, float) and np.isnan(value))

    annotations = raw.annotations
    trial_idx = np.flatnonzero(annotations.description == "contrast_trial_start")
    if trial_idx.size == 0:
        return raw

    all_extras = annotations.extras
    stim_onsets, stim_extras = [], []
    resp_onsets, resp_extras = [], []

    for i in trial_idx:
        extra = {} if all_extras is None else all_extras[i]
        trial_onset = float(annotations.onset[i])
        rt_trial = extra.get("rt_from_trialstart")

        stim_t = extra.get("stimulus_onset")
        if _missing(stim_t):
            rt_stim = extra.get("rt_from_stimulus")
            if rt_trial is not None and rt_stim is not None:
                stim_t = trial_onset + float(rt_trial) - float(rt_stim)

        resp_t = extra.get("response_onset")
        if _missing(resp_t) and rt_trial is not None:
            resp_t = trial_onset + float(rt_trial)

        if not _missing(stim_t):
            stim_onsets.append(float(stim_t))
            stim_extras.append(dict(extra, anchor="stimulus"))
        if not _missing(resp_t):
            resp_onsets.append(float(resp_t))
            resp_extras.append(dict(extra, anchor="response"))

    if not stim_onsets and not resp_onsets:
        return raw

    # All stimulus anchors come first, then all response anchors; the
    # descriptions and extras lists follow the same order.
    onsets = np.asarray(stim_onsets + resp_onsets, dtype=float)
    aux = mne.Annotations(
        onset=onsets,
        duration=np.zeros(onsets.shape, dtype=float),
        description=[stim_desc] * len(stim_onsets) + [resp_desc] * len(resp_onsets),
        orig_time=raw.info.get("meas_date"),
        extras=stim_extras + resp_extras,
    )
    raw.set_annotations(annotations + aux, verbose=False)
    return raw
300
+
301
+
302
def add_extras_columns(
    windows_concat_ds: BaseConcatDataset,
    original_concat_ds: BaseConcatDataset,
    desc: str = "contrast_trial_start",
    keys: tuple = (
        "target",
        "rt_from_stimulus",
        "rt_from_trialstart",
        "stimulus_onset",
        "response_onset",
        "correct",
        "response_type",
    ),
) -> BaseConcatDataset:
    """Add columns from annotation extras to a windowed dataset's metadata.

    This function propagates trial-level information stored in the `extras`
    of annotations to the `metadata` DataFrame of a `WindowsDataset`.

    The i-th windowed recording is paired with the i-th original recording.
    Within a recording, every metadata row whose ``i_window_in_trial`` is 0
    starts a new trial. The k-th trial receives the extras of the k-th
    annotation whose description equals ``desc``. Windows that come before
    the first trial start get missing values. NOTE(review): this assumes the
    windows were cut from those annotations in the same order; confirm for
    callers that reorder or subset windows.

    Column dtypes:

    * "correct" becomes nullable ``Int64`` (0/1).
    * "target", "rt_from_stimulus", "rt_from_trialstart", "stimulus_onset"
      and "response_onset" become nullable ``Float64``.
    * Any other key becomes ``string``.

    If a window dataset has a ``y`` attribute and its metadata has a
    "target" column, ``y`` is replaced by that column as an ``(N, 1)`` float
    array.

    Parameters
    ----------
    windows_concat_ds : BaseConcatDataset
        The windowed dataset whose metadata will be updated in place.
    original_concat_ds : BaseConcatDataset
        The original (non-windowed) dataset containing the raw data and
        annotations with the `extras` to be added.
    desc : str, default "contrast_trial_start"
        The description of the annotations to source the extras from.
    keys : tuple, default (...)
        The keys to extract from each annotation's `extras` dictionary and
        add as columns to the metadata. Keys absent from an annotation's
        extras become missing values.

    Returns
    -------
    BaseConcatDataset
        The `windows_concat_ds` with updated metadata.

    Raises
    ------
    ValueError
        Raised in either of two cases: the two datasets contain a different
        number of recordings, or a recording's number of windowed trials
        differs from its number of ``desc`` annotations.

    """
    float_cols = {
        "target",
        "rt_from_stimulus",
        "rt_from_trialstart",
        "stimulus_onset",
        "response_onset",
    }

    def _is_missing(value):
        # None, NaN and pd.NA all mean "no value". Only test scalars, since
        # pd.isna on a sequence returns an array.
        return value is None or (
            pd.api.types.is_scalar(value) and bool(pd.isna(value))
        )

    n_windowed = len(windows_concat_ds.datasets)
    n_original = len(original_concat_ds.datasets)
    if n_windowed != n_original:
        # zip() would otherwise truncate silently and attach labels to the
        # wrong recordings.
        raise ValueError(
            f"Dataset mismatch: {n_windowed} windowed recordings vs "
            f"{n_original} original recordings"
        )

    for win_ds, base_ds in zip(windows_concat_ds.datasets, original_concat_ds.datasets):
        ann = base_ds.raw.annotations
        idx = np.flatnonzero(ann.description == desc)
        if idx.size == 0:
            continue

        all_extras = ann.extras
        per_trial = [
            {k: (all_extras[i].get(k) if all_extras is not None else None) for k in keys}
            for i in idx
        ]
        n_per_trial = len(per_trial)

        md = win_ds.metadata.copy()
        first = md["i_window_in_trial"].to_numpy() == 0
        trial_ids = first.cumsum() - 1
        n_trials = int(trial_ids.max()) + 1 if trial_ids.size else 0
        if n_trials != n_per_trial:
            raise ValueError(f"Trial mismatch: {n_trials} vs {n_per_trial}")

        for k in keys:
            # Trial id -1 (windows before the first trial start) must not
            # wrap around to per_trial[-1].
            vals = [
                per_trial[t][k] if 0 <= t < n_per_trial else None for t in trial_ids
            ]
            if k == "correct":
                ser = pd.Series(
                    [None if _is_missing(v) else int(bool(v)) for v in vals],
                    index=md.index,
                    dtype="Int64",
                )
            elif k in float_cols:
                ser = pd.Series(
                    [np.nan if _is_missing(v) else float(v) for v in vals],
                    index=md.index,
                    dtype="Float64",
                )
            else:  # e.g. response_type
                ser = pd.Series(vals, index=md.index, dtype="string")

            md[k] = ser

        win_ds.metadata = md.reset_index(drop=True)
        if hasattr(win_ds, "y") and "target" in win_ds.metadata.columns:
            y_np = win_ds.metadata["target"].astype(float).to_numpy()
            win_ds.y = y_np[:, None]  # (N, 1)

    return windows_concat_ds
399
+
400
+
401
def keep_only_recordings_with(
    desc: str, concat_ds: BaseConcatDataset
) -> BaseConcatDataset:
    """Filter a concatenated dataset to keep only recordings with a specific annotation.

    A warning is logged for every recording that is dropped.

    Parameters
    ----------
    desc : str
        The description of the annotation that must be present in a recording
        for it to be kept.
    concat_ds : BaseConcatDataset
        The concatenated dataset to filter.

    Returns
    -------
    BaseConcatDataset
        A new concatenated dataset containing only the filtered recordings,
        in their original order. NOTE(review): if no recording matches, the
        empty list is passed to ``BaseConcatDataset``; confirm whether that
        constructor accepts it.

    """
    # A named logger, unlike the module-level `logging.warning`, never
    # implicitly calls `logging.basicConfig()` on the root logger.
    log = logging.getLogger(__name__)
    kept = []
    for ds in concat_ds.datasets:
        if np.any(ds.raw.annotations.description == desc):
            kept.append(ds)
        else:
            # Lazy %-formatting: the message is only built if it is emitted.
            log.warning(
                "Recording %s does not contain event '%s'", ds.raw.filenames[0], desc
            )
    return BaseConcatDataset(kept)
eegdash/logging.py ADDED
@@ -0,0 +1,54 @@
1
+ # Authors: The EEGDash contributors.
2
+ # License: BSD-3-Clause
3
+ # Copyright the EEGDash contributors.
4
+
5
+ """Logging configuration for EEGDash.
6
+
7
+ This module sets up centralized logging for the EEGDash package using Rich for enhanced
8
+ console output formatting. It provides a consistent logging interface across all modules.
9
+ """
10
+
11
+ import logging
12
+
13
+ from rich.logging import RichHandler
14
+
15
# Get the root logger
root_logger = logging.getLogger()

# NOTE(review): the statements below run as an import-time side effect and
# configure the *root* logger, not only the "eegdash" logger. Importing this
# module therefore discards any handlers the host application (or another
# library) had already attached to the root logger, and sets the root level
# to INFO. Consider confining this configuration to the "eegdash" logger.

# 1. Remove every handler currently attached to the root logger (including
#    ones installed by the application, not just defaults).
root_logger.handlers = []

# 2. Send every record that reaches the root logger through a single
#    RichHandler (rich-formatted tracebacks, Rich console markup enabled).
root_logger.addHandler(RichHandler(rich_tracebacks=True, markup=True))

# 3. Set the level for the root logger
root_logger.setLevel(logging.INFO)

# Package-specific logger. No handler is attached to it in this module, so
# its records propagate to the root logger's RichHandler configured above.
logger = logging.getLogger("eegdash")
"""The primary logger for the EEGDash package.

This logger is configured to use :class:`rich.logging.RichHandler` for
formatted, colorful output in the console. It inherits its base configuration
from the root logger, which is set to the ``INFO`` level.

Examples
--------
Usage in other modules:

.. code-block:: python

    from .logging import logger

    logger.info("This is an informational message.")
    logger.warning("This is a warning.")
    logger.error("This is an error.")
"""


# Explicit level on the package logger, so its threshold does not depend on
# later changes to the root logger's level.
logger.setLevel(logging.INFO)

__all__ = ["logger"]
eegdash/mongodb.py CHANGED
@@ -1,42 +1,66 @@
1
+ # Authors: The EEGDash contributors.
2
+ # License: BSD-3-Clause
3
+ # Copyright the EEGDash contributors.
4
+
5
+ """MongoDB connection and operations management.
6
+
7
+ This module provides a thread-safe singleton manager for MongoDB connections,
8
+ ensuring that connections to the database are handled efficiently and consistently
9
+ across the application.
10
+ """
11
+
1
12
  import threading
2
13
 
3
14
  from pymongo import MongoClient
4
-
5
- # MongoDB Operations
6
- # These methods provide a high-level interface to interact with the MongoDB
7
- # collection, allowing users to find, add, and update EEG data records.
8
- # - find:
9
- # - exist:
10
- # - add_request:
11
- # - add:
12
- # - update_request:
13
- # - remove_field:
14
- # - remove_field_from_db:
15
- # - close: Close the MongoDB connection.
16
- # - __del__: Destructor to close the MongoDB connection.
15
+ from pymongo.collection import Collection
16
+ from pymongo.database import Database
17
17
 
18
18
 
19
19
  class MongoConnectionManager:
20
- """Singleton class to manage MongoDB client connections."""
20
+ """A thread-safe singleton to manage MongoDB client connections.
21
+
22
+ This class ensures that only one connection instance is created for each
23
+ unique combination of a connection string and staging flag. It provides
24
+ class methods to get a client and to close all active connections.
25
+
26
+ Attributes
27
+ ----------
28
+ _instances : dict
29
+ A dictionary to store singleton instances, mapping a
30
+ (connection_string, is_staging) tuple to a (client, db, collection)
31
+ tuple.
32
+ _lock : threading.Lock
33
+ A lock to ensure thread-safe instantiation of clients.
21
34
 
22
- _instances = {}
35
+ """
36
+
37
+ _instances: dict[tuple[str, bool], tuple[MongoClient, Database, Collection]] = {}
23
38
  _lock = threading.Lock()
24
39
 
25
40
  @classmethod
26
- def get_client(cls, connection_string: str, is_staging: bool = False):
27
- """Get or create a MongoDB client for the given connection string and staging flag.
41
+ def get_client(
42
+ cls, connection_string: str, is_staging: bool = False
43
+ ) -> tuple[MongoClient, Database, Collection]:
44
+ """Get or create a MongoDB client for the given connection parameters.
45
+
46
+ This method returns a cached client if one already exists for the given
47
+ connection string and staging flag. Otherwise, it creates a new client,
48
+ connects to the appropriate database ("eegdash" or "eegdashstaging"),
49
+ and returns the client, database, and "records" collection.
28
50
 
29
51
  Parameters
30
52
  ----------
31
53
  connection_string : str
32
- The MongoDB connection string
33
- is_staging : bool
34
- Whether to use staging database
54
+ The MongoDB connection string.
55
+ is_staging : bool, default False
56
+ If True, connect to the staging database ("eegdashstaging").
57
+ Otherwise, connect to the production database ("eegdash").
35
58
 
36
59
  Returns
37
60
  -------
38
- tuple
39
- A tuple of (client, database, collection)
61
+ tuple[MongoClient, Database, Collection]
62
+ A tuple containing the connected MongoClient instance, the Database
63
+ object, and the Collection object for the "records" collection.
40
64
 
41
65
  """
42
66
  # Create a unique key based on connection string and staging flag
@@ -55,8 +79,12 @@ class MongoConnectionManager:
55
79
  return cls._instances[key]
56
80
 
57
81
  @classmethod
58
- def close_all(cls):
59
- """Close all MongoDB client connections."""
82
+ def close_all(cls) -> None:
83
+ """Close all managed MongoDB client connections.
84
+
85
+ This method iterates through all cached client instances and closes
86
+ their connections. It also clears the instance cache.
87
+ """
60
88
  with cls._lock:
61
89
  for client, _, _ in cls._instances.values():
62
90
  try:
@@ -64,3 +92,6 @@ class MongoConnectionManager:
64
92
  except Exception:
65
93
  pass
66
94
  cls._instances.clear()
95
+
96
+
97
+ __all__ = ["MongoConnectionManager"]
eegdash/paths.py ADDED
@@ -0,0 +1,52 @@
1
+ # Authors: The EEGDash contributors.
2
+ # License: BSD-3-Clause
3
+ # Copyright the EEGDash contributors.
4
+
5
+ """Path utilities and cache directory management.
6
+
7
+ This module provides functions for resolving consistent cache directories and path
8
+ management throughout the EEGDash package, with integration to MNE-Python's
9
+ configuration system.
10
+ """
11
+
12
+ from __future__ import annotations
13
+
14
+ import os
15
+ from pathlib import Path
16
+
17
+ from mne.utils import get_config as mne_get_config
18
+
19
+
20
def get_default_cache_dir() -> Path:
    """Resolve the default cache directory for EEGDash data.

    The function determines the cache directory based on the following
    priority order:

    1. The path specified by the ``EEGDASH_CACHE_DIR`` environment variable
       (``~`` is expanded).
    2. The ``MNE_DATA`` value returned by MNE-Python's ``get_config``. That
       function consults an environment variable of the same name as well
       as the MNE-Python config file.
    3. A hidden directory named ``.eegdash_cache`` in the current working
       directory.

    Empty values in steps 1 and 2 are treated as unset.

    Returns
    -------
    pathlib.Path
        The resolved, absolute path to the default cache directory. The
        directory is not created by this function.

    """
    # 1) Explicit env var wins
    env_dir = os.environ.get("EEGDASH_CACHE_DIR")
    if env_dir:
        return Path(env_dir).expanduser().resolve()

    # 2) Reuse MNE's data cache location if configured
    mne_data = mne_get_config("MNE_DATA")
    if mne_data:
        return Path(mne_data).expanduser().resolve()

    # 3) Default to a project-local hidden folder. resolve() keeps this
    #    branch consistent with the others (symlinks resolved).
    return (Path.cwd() / ".eegdash_cache").resolve()


__all__ = ["get_default_cache_dir"]
eegdash/utils.py CHANGED
@@ -1,7 +1,32 @@
1
+ # Authors: The EEGDash contributors.
2
+ # License: BSD-3-Clause
3
+ # Copyright the EEGDash contributors.
4
+
5
+ """General utility functions for EEGDash.
6
+
7
+ This module contains miscellaneous utility functions used across the EEGDash package,
8
+ including MongoDB client initialization and configuration helpers.
9
+ """
10
+
1
11
  from mne.utils import get_config, set_config, use_log_level
2
12
 
3
13
 
4
- def __init__mongo_client():
14
+ def _init_mongo_client() -> None:
15
+ """Initialize the default MongoDB connection URI in the MNE config.
16
+
17
+ This function checks if the ``EEGDASH_DB_URI`` is already set in the
18
+ MNE-Python configuration. If it is not set, this function sets it to the
19
+ default public EEGDash MongoDB Atlas cluster URI.
20
+
21
+ The operation is performed with MNE's logging level temporarily set to
22
+ "ERROR" to suppress verbose output.
23
+
24
+ Notes
25
+ -----
26
+ This is an internal helper function and is not intended for direct use
27
+ by end-users.
28
+
29
+ """
5
30
  with use_log_level("ERROR"):
6
31
  if get_config("EEGDASH_DB_URI") is None:
7
32
  set_config(
@@ -9,3 +34,6 @@ def __init__mongo_client():
9
34
  "mongodb+srv://eegdash-user:mdzoMjQcHWTVnKDq@cluster0.vz35p.mongodb.net/?retryWrites=true&w=majority&appName=Cluster0",
10
35
  set_env=True,
11
36
  )
37
+
38
+
39
+ __all__ = ["_init_mongo_client"]