eegdash 0.3.7.dev104__py3-none-any.whl → 0.3.7.dev105__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of eegdash might be problematic. Click here for more details.

eegdash/api.py CHANGED
@@ -3,6 +3,7 @@ import os
3
3
  import tempfile
4
4
  from pathlib import Path
5
5
  from typing import Any, Mapping
6
+ from urllib.parse import urlsplit
6
7
 
7
8
  import mne
8
9
  import numpy as np
@@ -11,14 +12,18 @@ import xarray as xr
11
12
  from dotenv import load_dotenv
12
13
  from joblib import Parallel, delayed
13
14
  from mne.utils import warn
14
- from mne_bids import get_bids_path_from_fname, read_raw_bids
15
+ from mne_bids import find_matching_paths, get_bids_path_from_fname, read_raw_bids
15
16
  from pymongo import InsertOne, UpdateOne
16
17
  from s3fs import S3FileSystem
17
18
 
18
19
  from braindecode.datasets import BaseConcatDataset
19
20
 
20
- from .const import RELEASE_TO_OPENNEURO_DATASET_MAP
21
- from .data_config import config as data_config
21
+ from .bids_eeg_metadata import build_query_from_kwargs, load_eeg_attrs_from_bids_file
22
+ from .const import (
23
+ ALLOWED_QUERY_FIELDS,
24
+ RELEASE_TO_OPENNEURO_DATASET_MAP,
25
+ )
26
+ from .const import config as data_config
22
27
  from .data_utils import EEGBIDSDataset, EEGDashBaseDataset
23
28
  from .mongodb import MongoConnectionManager
24
29
 
@@ -26,46 +31,31 @@ logger = logging.getLogger("eegdash")
26
31
 
27
32
 
28
33
  class EEGDash:
29
- """A high-level interface to the EEGDash database.
34
+ """High-level interface to the EEGDash metadata database.
30
35
 
31
- This class is primarily used to interact with the metadata records stored in the
32
- EEGDash database (or a private instance of it), allowing users to find, add, and
33
- update EEG data records.
34
-
35
- While this class provides basic support for loading EEG data, please see
36
- the EEGDashDataset class for a more complete way to retrieve and work with full
37
- datasets.
36
+ Provides methods to query, insert, and update metadata records stored in the
37
+ EEGDash MongoDB database (public or private). Also includes utilities to load
38
+ EEG data from S3 for matched records.
38
39
 
40
+ For working with collections of
41
+ recordings as PyTorch datasets, prefer :class:`EEGDashDataset`.
39
42
  """
40
43
 
41
- _ALLOWED_QUERY_FIELDS = {
42
- "data_name",
43
- "dataset",
44
- "subject",
45
- "task",
46
- "session",
47
- "run",
48
- "modality",
49
- "sampling_frequency",
50
- "nchans",
51
- "ntimes",
52
- }
53
-
54
44
  def __init__(self, *, is_public: bool = True, is_staging: bool = False) -> None:
55
- """Create new instance of the EEGDash Database client.
45
+ """Create a new EEGDash client.
56
46
 
57
47
  Parameters
58
48
  ----------
59
- is_public: bool
60
- Whether to connect to the public MongoDB database; if False, connect to a
61
- private database instance as per the DB_CONNECTION_STRING env variable
62
- (or .env file entry).
63
- is_staging: bool
64
- If True, use staging MongoDB database ("eegdashstaging"); otherwise use the
65
- production database ("eegdash").
66
-
67
- Example
68
- -------
49
+ is_public : bool, default True
50
+ Connect to the public MongoDB database. If ``False``, connect to a
51
+ private database instance using the ``DB_CONNECTION_STRING`` environment
52
+ variable (or value from a ``.env`` file).
53
+ is_staging : bool, default False
54
+ If ``True``, use the staging database (``eegdashstaging``); otherwise
55
+ use the production database (``eegdash``).
56
+
57
+ Examples
58
+ --------
69
59
  >>> eegdash = EEGDash()
70
60
 
71
61
  """
@@ -106,23 +96,25 @@ class EEGDash:
106
96
 
107
97
  Parameters
108
98
  ----------
109
- query: dict, optional
110
- A complete MongoDB query dictionary. This is a positional-only argument.
111
- **kwargs:
112
- Keyword arguments representing field-value pairs for the query.
113
- Values can be single items (str, int) or lists of items for multi-search.
99
+ query : dict, optional
100
+ Complete MongoDB query dictionary. This is a positional-only
101
+ argument.
102
+ **kwargs
103
+ User-friendly field filters that are converted to a MongoDB query.
104
+ Values can be scalars (e.g., ``"sub-01"``) or sequences (translated
105
+ to ``$in`` queries).
114
106
 
115
107
  Returns
116
108
  -------
117
- list:
118
- A list of DB records (string-keyed dictionaries) that match the query.
109
+ list of dict
110
+ DB records that match the query.
119
111
 
120
112
  """
121
113
  final_query: dict[str, Any] | None = None
122
114
 
123
115
  # Accept explicit empty dict {} to mean "match all"
124
116
  raw_query = query if isinstance(query, dict) else None
125
- kwargs_query = self._build_query_from_kwargs(**kwargs) if kwargs else None
117
+ kwargs_query = build_query_from_kwargs(**kwargs) if kwargs else None
126
118
 
127
119
  # Determine presence, treating {} as a valid raw query
128
120
  has_raw = isinstance(raw_query, dict)
@@ -239,59 +231,12 @@ class EEGDash:
239
231
  return record
240
232
 
241
233
  def _build_query_from_kwargs(self, **kwargs) -> dict[str, Any]:
242
- """Build and validate a MongoDB query from user-friendly keyword arguments.
234
+ """Internal helper to build a validated MongoDB query from keyword args.
243
235
 
244
- Improvements:
245
- - Reject None values and empty/whitespace-only strings
246
- - For list/tuple/set values: strip strings, drop None/empties, deduplicate, and use `$in`
247
- - Preserve scalars as exact matches
236
+ This delegates to the module-level builder used across the package and
237
+ is exposed here for testing and convenience.
248
238
  """
249
- # 1. Validate that all provided keys are allowed for querying
250
- unknown_fields = set(kwargs.keys()) - self._ALLOWED_QUERY_FIELDS
251
- if unknown_fields:
252
- raise ValueError(
253
- f"Unsupported query field(s): {', '.join(sorted(unknown_fields))}. "
254
- f"Allowed fields are: {', '.join(sorted(self._ALLOWED_QUERY_FIELDS))}"
255
- )
256
-
257
- # 2. Construct the query dictionary
258
- query = {}
259
- for key, value in kwargs.items():
260
- # None is not a valid constraint
261
- if value is None:
262
- raise ValueError(
263
- f"Received None for query parameter '{key}'. Provide a concrete value."
264
- )
265
-
266
- # Handle list-like values as multi-constraints
267
- if isinstance(value, (list, tuple, set)):
268
- cleaned: list[Any] = []
269
- for item in value:
270
- if item is None:
271
- continue
272
- if isinstance(item, str):
273
- item = item.strip()
274
- if not item:
275
- continue
276
- cleaned.append(item)
277
- # Deduplicate while preserving order
278
- cleaned = list(dict.fromkeys(cleaned))
279
- if not cleaned:
280
- raise ValueError(
281
- f"Received an empty list for query parameter '{key}'. This is not supported."
282
- )
283
- query[key] = {"$in": cleaned}
284
- else:
285
- # Scalars: trim strings and validate
286
- if isinstance(value, str):
287
- value = value.strip()
288
- if not value:
289
- raise ValueError(
290
- f"Received an empty string for query parameter '{key}'."
291
- )
292
- query[key] = value
293
-
294
- return query
239
+ return build_query_from_kwargs(**kwargs)
295
240
 
296
241
  # --- Query merging and conflict detection helpers ---
297
242
  def _extract_simple_constraint(self, query: dict[str, Any], key: str):
@@ -324,8 +269,8 @@ class EEGDash:
324
269
  return
325
270
 
326
271
  # Only consider fields we generally allow; skip meta operators like $and
327
- raw_keys = set(raw_query.keys()) & self._ALLOWED_QUERY_FIELDS
328
- kw_keys = set(kwargs_query.keys()) & self._ALLOWED_QUERY_FIELDS
272
+ raw_keys = set(raw_query.keys()) & ALLOWED_QUERY_FIELDS
273
+ kw_keys = set(kwargs_query.keys()) & ALLOWED_QUERY_FIELDS
329
274
  dup_keys = raw_keys & kw_keys
330
275
  for key in dup_keys:
331
276
  rc = self._extract_simple_constraint(raw_query, key)
@@ -360,44 +305,95 @@ class EEGDash:
360
305
  )
361
306
 
362
307
  def load_eeg_data_from_s3(self, s3path: str) -> xr.DataArray:
363
- """Load an EEGLAB .set file from an AWS S3 URI and return it as an xarray DataArray.
308
+ """Load EEG data from an S3 URI into an ``xarray.DataArray``.
309
+
310
+ Preserves the original filename, downloads sidecar files when applicable
311
+ (e.g., ``.fdt`` for EEGLAB, ``.vmrk``/``.eeg`` for BrainVision), and uses
312
+ MNE's direct readers.
364
313
 
365
314
  Parameters
366
315
  ----------
367
316
  s3path : str
368
- An S3 URI (should start with "s3://") for the file in question.
317
+ An S3 URI (should start with "s3://").
369
318
 
370
319
  Returns
371
320
  -------
372
321
  xr.DataArray
373
- A DataArray containing the EEG data, with dimensions "channel" and "time".
322
+ EEG data with dimensions ``("channel", "time")``.
374
323
 
375
- Example
376
- -------
377
- >>> eegdash = EEGDash()
378
- >>> mypath = "s3://openneuro.org/path/to/your/eeg_data.set"
379
- >>> mydata = eegdash.load_eeg_data_from_s3(mypath)
324
+ Raises
325
+ ------
326
+ ValueError
327
+ If the file extension is unsupported.
380
328
 
381
329
  """
382
- with tempfile.NamedTemporaryFile(delete=False, suffix=".set") as tmp:
383
- with self.filesystem.open(s3path) as s3_file:
384
- tmp.write(s3_file.read())
385
- tmp_path = tmp.name
386
- eeg_data = self.load_eeg_data_from_bids_file(tmp_path)
387
- os.unlink(tmp_path)
388
- return eeg_data
330
+ # choose a temp dir so sidecars can be colocated
331
+ with tempfile.TemporaryDirectory() as tmpdir:
332
+ # Derive local filenames from the S3 key to keep base name consistent
333
+ s3_key = urlsplit(s3path).path # e.g., "/dsXXXX/sub-.../..._eeg.set"
334
+ basename = Path(s3_key).name
335
+ ext = Path(basename).suffix.lower()
336
+ local_main = Path(tmpdir) / basename
337
+
338
+ # Download main file
339
+ with (
340
+ self.filesystem.open(s3path, mode="rb") as fsrc,
341
+ open(local_main, "wb") as fdst,
342
+ ):
343
+ fdst.write(fsrc.read())
344
+
345
+ # Determine and fetch any required sidecars
346
+ sidecars: list[str] = []
347
+ if ext == ".set": # EEGLAB
348
+ sidecars = [".fdt"]
349
+ elif ext == ".vhdr": # BrainVision
350
+ sidecars = [".vmrk", ".eeg", ".dat", ".raw"]
351
+
352
+ for sc_ext in sidecars:
353
+ sc_key = s3_key[: -len(ext)] + sc_ext
354
+ sc_uri = f"s3://{urlsplit(s3path).netloc}{sc_key}"
355
+ try:
356
+ # If sidecar exists, download next to the main file
357
+ info = self.filesystem.info(sc_uri)
358
+ if info:
359
+ sc_local = Path(tmpdir) / Path(sc_key).name
360
+ with (
361
+ self.filesystem.open(sc_uri, mode="rb") as fsrc,
362
+ open(sc_local, "wb") as fdst,
363
+ ):
364
+ fdst.write(fsrc.read())
365
+ except Exception:
366
+ # Sidecar not present; skip silently
367
+ pass
368
+
369
+ # Read using appropriate MNE reader
370
+ raw = mne.io.read_raw(str(local_main), preload=True, verbose=False)
371
+
372
+ data = raw.get_data()
373
+ fs = raw.info["sfreq"]
374
+ max_time = data.shape[1] / fs
375
+ time_steps = np.linspace(0, max_time, data.shape[1]).squeeze()
376
+ channel_names = raw.ch_names
377
+
378
+ return xr.DataArray(
379
+ data=data,
380
+ dims=["channel", "time"],
381
+ coords={"time": time_steps, "channel": channel_names},
382
+ )
389
383
 
390
384
  def load_eeg_data_from_bids_file(self, bids_file: str) -> xr.DataArray:
391
- """Load EEG data from a local file and return it as a xarray DataArray.
385
+ """Load EEG data from a local BIDS-formatted file.
392
386
 
393
387
  Parameters
394
388
  ----------
395
389
  bids_file : str
396
- Path to the BIDS-compliant file on the local filesystem.
390
+ Path to a BIDS-compliant EEG file (e.g., ``*_eeg.edf``, ``*_eeg.bdf``,
391
+ ``*_eeg.vhdr``, ``*_eeg.set``).
397
392
 
398
- Notes
399
- -----
400
- Currently, only non-epoched .set files are supported.
393
+ Returns
394
+ -------
395
+ xr.DataArray
396
+ EEG data with dimensions ``("channel", "time")``.
401
397
 
402
398
  """
403
399
  bids_path = get_bids_path_from_fname(bids_file, verbose=False)
@@ -417,140 +413,25 @@ class EEGDash:
417
413
  )
418
414
  return eeg_xarray
419
415
 
420
- def get_raw_extensions(
421
- self, bids_file: str, bids_dataset: EEGBIDSDataset
422
- ) -> list[str]:
423
- """Helper to find paths to additional "sidecar" files that may be associated
424
- with a given main data file in a BIDS dataset; paths are returned as relative to
425
- the parent dataset path.
426
-
427
- For example, if the input file is a .set file, this will return the relative path
428
- to a corresponding .fdt file (if any).
429
- """
430
- bids_file = Path(bids_file)
431
- extensions = {
432
- ".set": [".set", ".fdt"], # eeglab
433
- ".edf": [".edf"], # european
434
- ".vhdr": [".eeg", ".vhdr", ".vmrk", ".dat", ".raw"], # brainvision
435
- ".bdf": [".bdf"], # biosemi
436
- }
437
- return [
438
- str(bids_dataset.get_relative_bidspath(bids_file.with_suffix(suffix)))
439
- for suffix in extensions[bids_file.suffix]
440
- if bids_file.with_suffix(suffix).exists()
441
- ]
442
-
443
- def load_eeg_attrs_from_bids_file(
444
- self, bids_dataset: EEGBIDSDataset, bids_file: str
445
- ) -> dict[str, Any]:
446
- """Build the metadata record for a given BIDS file (single recording) in a BIDS dataset.
447
-
448
- Attributes are at least the ones defined in data_config attributes (set to None if missing),
449
- but are typically a superset, and include, among others, the paths to relevant
450
- meta-data files needed to load and interpret the file in question.
451
-
452
- Parameters
453
- ----------
454
- bids_dataset : EEGBIDSDataset
455
- The BIDS dataset object containing the file.
456
- bids_file : str
457
- The path to the BIDS file within the dataset.
458
-
459
- Returns
460
- -------
461
- dict:
462
- A dictionary representing the metadata record for the given file. This is the
463
- same format as the records stored in the database.
464
-
465
- """
466
- if bids_file not in bids_dataset.files:
467
- raise ValueError(f"{bids_file} not in {bids_dataset.dataset}")
468
-
469
- # Initialize attrs with None values for all expected fields
470
- attrs = {field: None for field in self.config["attributes"].keys()}
471
-
472
- file = Path(bids_file).name
473
- dsnumber = bids_dataset.dataset
474
- # extract openneuro path by finding the first occurrence of the dataset name in the filename and remove the path before that
475
- openneuro_path = dsnumber + bids_file.split(dsnumber)[1]
476
-
477
- # Update with actual values where available
478
- try:
479
- participants_tsv = bids_dataset.subject_participant_tsv(bids_file)
480
- except Exception as e:
481
- logger.error("Error getting participants_tsv: %s", str(e))
482
- participants_tsv = None
483
-
484
- try:
485
- eeg_json = bids_dataset.eeg_json(bids_file)
486
- except Exception as e:
487
- logger.error("Error getting eeg_json: %s", str(e))
488
- eeg_json = None
489
-
490
- bids_dependencies_files = self.config["bids_dependencies_files"]
491
- bidsdependencies = []
492
- for extension in bids_dependencies_files:
493
- try:
494
- dep_path = bids_dataset.get_bids_metadata_files(bids_file, extension)
495
- dep_path = [
496
- str(bids_dataset.get_relative_bidspath(dep)) for dep in dep_path
497
- ]
498
- bidsdependencies.extend(dep_path)
499
- except Exception:
500
- pass
501
-
502
- bidsdependencies.extend(self.get_raw_extensions(bids_file, bids_dataset))
503
-
504
- # Define field extraction functions with error handling
505
- field_extractors = {
506
- "data_name": lambda: f"{bids_dataset.dataset}_{file}",
507
- "dataset": lambda: bids_dataset.dataset,
508
- "bidspath": lambda: openneuro_path,
509
- "subject": lambda: bids_dataset.get_bids_file_attribute(
510
- "subject", bids_file
511
- ),
512
- "task": lambda: bids_dataset.get_bids_file_attribute("task", bids_file),
513
- "session": lambda: bids_dataset.get_bids_file_attribute(
514
- "session", bids_file
515
- ),
516
- "run": lambda: bids_dataset.get_bids_file_attribute("run", bids_file),
517
- "modality": lambda: bids_dataset.get_bids_file_attribute(
518
- "modality", bids_file
519
- ),
520
- "sampling_frequency": lambda: bids_dataset.get_bids_file_attribute(
521
- "sfreq", bids_file
522
- ),
523
- "nchans": lambda: bids_dataset.get_bids_file_attribute("nchans", bids_file),
524
- "ntimes": lambda: bids_dataset.get_bids_file_attribute("ntimes", bids_file),
525
- "participant_tsv": lambda: participants_tsv,
526
- "eeg_json": lambda: eeg_json,
527
- "bidsdependencies": lambda: bidsdependencies,
528
- }
529
-
530
- # Dynamically populate attrs with error handling
531
- for field, extractor in field_extractors.items():
532
- try:
533
- attrs[field] = extractor()
534
- except Exception as e:
535
- logger.error("Error extracting %s : %s", field, str(e))
536
- attrs[field] = None
537
-
538
- return attrs
539
-
540
416
  def add_bids_dataset(
541
417
  self, dataset: str, data_dir: str, overwrite: bool = True
542
418
  ) -> None:
543
- """Traverse the BIDS dataset at data_dir and add its records to the MongoDB database,
544
- under the given dataset name.
419
+ """Scan a local BIDS dataset and upsert records into MongoDB.
545
420
 
546
421
  Parameters
547
422
  ----------
548
- dataset : str)
549
- The name of the dataset to be added (e.g., "ds002718").
423
+ dataset : str
424
+ Dataset identifier (e.g., ``"ds002718"``).
550
425
  data_dir : str
551
- The path to the BIDS dataset directory.
552
- overwrite : bool
553
- Whether to overwrite/update existing records in the database.
426
+ Path to the local BIDS dataset directory.
427
+ overwrite : bool, default True
428
+ If ``True``, update existing records when encountered; otherwise,
429
+ skip records that already exist.
430
+
431
+ Raises
432
+ ------
433
+ ValueError
434
+ If called on a public client ``(is_public=True)``.
554
435
 
555
436
  """
556
437
  if self.is_public:
@@ -565,7 +446,7 @@ class EEGDash:
565
446
  dataset=dataset,
566
447
  )
567
448
  except Exception as e:
568
- logger.error("Error creating bids dataset %s: $s", dataset, str(e))
449
+ logger.error("Error creating bids dataset %s: %s", dataset, str(e))
569
450
  raise e
570
451
  requests = []
571
452
  for bids_file in bids_dataset.get_files():
@@ -574,15 +455,13 @@ class EEGDash:
574
455
 
575
456
  if self.exist({"data_name": data_id}):
576
457
  if overwrite:
577
- eeg_attrs = self.load_eeg_attrs_from_bids_file(
458
+ eeg_attrs = load_eeg_attrs_from_bids_file(
578
459
  bids_dataset, bids_file
579
460
  )
580
- requests.append(self.update_request(eeg_attrs))
461
+ requests.append(self._update_request(eeg_attrs))
581
462
  else:
582
- eeg_attrs = self.load_eeg_attrs_from_bids_file(
583
- bids_dataset, bids_file
584
- )
585
- requests.append(self.add_request(eeg_attrs))
463
+ eeg_attrs = load_eeg_attrs_from_bids_file(bids_dataset, bids_file)
464
+ requests.append(self._add_request(eeg_attrs))
586
465
  except Exception as e:
587
466
  logger.error("Error adding record %s", bids_file)
588
467
  logger.error(str(e))
@@ -598,22 +477,22 @@ class EEGDash:
598
477
  logger.info("Errors: %s ", result.bulk_api_result.get("writeErrors", []))
599
478
 
600
479
  def get(self, query: dict[str, Any]) -> list[xr.DataArray]:
601
- """Retrieve a list of EEG data arrays that match the given query. See also
602
- the `find()` method for details on the query format.
480
+ """Download and return EEG data arrays for records matching a query.
603
481
 
604
482
  Parameters
605
483
  ----------
606
484
  query : dict
607
- A dictionary that specifies the query to be executed; this is a reference
608
- document that is used to match records in the MongoDB collection.
485
+ MongoDB query used to select records.
609
486
 
610
487
  Returns
611
488
  -------
612
- A list of xarray DataArray objects containing the EEG data for each matching record.
489
+ list of xr.DataArray
490
+ EEG data for each matching record, with dimensions ``("channel", "time")``.
613
491
 
614
492
  Notes
615
493
  -----
616
- Retrieval is done in parallel, and the downloaded data are not cached locally.
494
+ Retrieval runs in parallel. Downloaded files are read and discarded
495
+ (no on-disk caching here).
617
496
 
618
497
  """
619
498
  sessions = self.find(query)
@@ -623,12 +502,40 @@ class EEGDash:
623
502
  results = Parallel(
624
503
  n_jobs=-1 if len(sessions) > 1 else 1, prefer="threads", verbose=1
625
504
  )(
626
- delayed(self.load_eeg_data_from_s3)(self.get_s3path(session))
505
+ delayed(self.load_eeg_data_from_s3)(self._get_s3path(session))
627
506
  for session in sessions
628
507
  )
629
508
  return results
630
509
 
631
- def add_request(self, record: dict):
510
+ def _get_s3path(self, record: Mapping[str, Any] | str) -> str:
511
+ """Build an S3 URI from a DB record or a relative path.
512
+
513
+ Parameters
514
+ ----------
515
+ record : dict or str
516
+ Either a DB record containing a ``'bidspath'`` key, or a relative
517
+ path string under the OpenNeuro bucket.
518
+
519
+ Returns
520
+ -------
521
+ str
522
+ Fully qualified S3 URI.
523
+
524
+ Raises
525
+ ------
526
+ ValueError
527
+ If a mapping is provided but ``'bidspath'`` is missing.
528
+
529
+ """
530
+ if isinstance(record, str):
531
+ rel = record
532
+ else:
533
+ rel = record.get("bidspath")
534
+ if not rel:
535
+ raise ValueError("Record missing 'bidspath' for S3 path resolution")
536
+ return f"s3://openneuro.org/{rel}"
537
+
538
+ def _add_request(self, record: dict):
632
539
  """Internal helper method to create a MongoDB insertion request for a record."""
633
540
  return InsertOne(record)
634
541
 
@@ -642,12 +549,19 @@ class EEGDash:
642
549
  except:
643
550
  logger.error("Error adding record: %s ", record["data_name"])
644
551
 
645
- def update_request(self, record: dict):
552
+ def _update_request(self, record: dict):
646
553
  """Internal helper method to create a MongoDB update request for a record."""
647
554
  return UpdateOne({"data_name": record["data_name"]}, {"$set": record})
648
555
 
649
556
  def update(self, record: dict):
650
- """Update a single record in the MongoDB collection."""
557
+ """Update a single record in the MongoDB collection.
558
+
559
+ Parameters
560
+ ----------
561
+ record : dict
562
+ Record content to set at the matching ``data_name``.
563
+
564
+ """
651
565
  try:
652
566
  self.__collection.update_one(
653
567
  {"data_name": record["data_name"]}, {"$set": record}
@@ -655,15 +569,33 @@ class EEGDash:
655
569
  except: # silent failure
656
570
  logger.error("Error updating record: %s", record["data_name"])
657
571
 
572
+ def exists(self, query: dict[str, Any]) -> bool:
573
+ """Alias for :meth:`exist` provided for API clarity."""
574
+ return self.exist(query)
575
+
658
576
  def remove_field(self, record, field):
659
- """Remove a specific field from a record in the MongoDB collection."""
577
+ """Remove a specific field from a record in the MongoDB collection.
578
+
579
+ Parameters
580
+ ----------
581
+ record : dict
582
+ Record identifying object with ``data_name``.
583
+ field : str
584
+ Field name to remove.
585
+
586
+ """
660
587
  self.__collection.update_one(
661
588
  {"data_name": record["data_name"]}, {"$unset": {field: 1}}
662
589
  )
663
590
 
664
591
  def remove_field_from_db(self, field):
665
- """Removed all occurrences of a specific field from all records in the MongoDB
666
- collection. WARNING: this operation is destructive and should be used with caution.
592
+ """Remove a field from all records (destructive).
593
+
594
+ Parameters
595
+ ----------
596
+ field : str
597
+ Field name to remove from every document.
598
+
667
599
  """
668
600
  self.__collection.update_many({}, {"$unset": {field: 1}})
669
601
 
@@ -673,11 +605,13 @@ class EEGDash:
673
605
  return self.__collection
674
606
 
675
607
  def close(self):
676
- """Close the MongoDB client connection.
608
+ """Backward-compatibility no-op; connections are managed globally.
609
+
610
+ Notes
611
+ -----
612
+ Connections are managed by :class:`MongoConnectionManager`. Use
613
+ :meth:`close_all_connections` to explicitly close all clients.
677
614
 
678
- Note: Since MongoDB clients are now managed by a singleton,
679
- this method no longer closes connections. Use close_all_connections()
680
- class method to close all connections if needed.
681
615
  """
682
616
  # Individual instances no longer close the shared client
683
617
  pass
@@ -688,7 +622,7 @@ class EEGDash:
688
622
  MongoConnectionManager.close_all()
689
623
 
690
624
  def __del__(self):
691
- """Ensure connection is closed when object is deleted."""
625
+ """Destructor; no explicit action needed due to global connection manager."""
692
626
  # No longer needed since we're using singleton pattern
693
627
  pass
694
628
 
@@ -708,17 +642,16 @@ class EEGDashDataset(BaseConcatDataset):
708
642
  "sex",
709
643
  ],
710
644
  s3_bucket: str | None = None,
711
- eeg_dash_instance=None,
712
645
  records: list[dict] | None = None,
713
- offline_mode: bool = False,
646
+ download: bool = True,
714
647
  n_jobs: int = -1,
648
+ eeg_dash_instance: EEGDash | None = None,
715
649
  **kwargs,
716
650
  ):
717
651
  """Create a new EEGDashDataset from a given query or local BIDS dataset directory
718
652
  and dataset name. An EEGDashDataset is pooled collection of EEGDashBaseDataset
719
653
  instances (individual recordings) and is a subclass of braindecode's BaseConcatDataset.
720
654
 
721
-
722
655
  Querying Examples:
723
656
  ------------------
724
657
  # Find by single subject
@@ -734,57 +667,91 @@ class EEGDashDataset(BaseConcatDataset):
734
667
 
735
668
  Parameters
736
669
  ----------
670
+ cache_dir : str | Path
671
+ Directory where data are cached locally. If not specified, a default
672
+ cache directory under the user cache is used.
737
673
  query : dict | None
738
- A raw MongoDB query dictionary. If provided, keyword arguments for filtering are ignored.
739
- **kwargs : dict
740
- Keyword arguments for filtering (e.g., `subject="X"`, `task=["T1", "T2"]`) and/or
741
- arguments to be passed to the EEGDashBaseDataset constructor (e.g., `subject=...`).
742
- cache_dir : str
743
- A directory where the dataset will be cached locally.
744
- data_dir : str | None
745
- Optionally a string specifying a local BIDS dataset directory from which to load the EEG data files. Exactly one
746
- of query or data_dir must be provided.
747
- dataset : str | None
748
- If data_dir is given, a name for the dataset to be loaded.
674
+ Raw MongoDB query to filter records. If provided, it is merged with
675
+ keyword filtering arguments (see ``**kwargs``) using logical AND.
676
+ You must provide at least a ``dataset`` (either in ``query`` or
677
+ as a keyword argument). Only fields in ``ALLOWED_QUERY_FIELDS`` are
678
+ considered for filtering.
749
679
  description_fields : list[str]
750
- A list of fields to be extracted from the dataset records
751
- and included in the returned data description(s). Examples are typical
752
- subject metadata fields such as "subject", "session", "run", "task", etc.;
753
- see also data_config.description_fields for the default set of fields.
680
+ Fields to extract from each record and include in dataset descriptions
681
+ (e.g., "subject", "session", "run", "task").
754
682
  s3_bucket : str | None
755
- An optional S3 bucket URI (e.g., "s3://mybucket") to use instead of the
756
- default OpenNeuro bucket for loading data files
683
+ Optional S3 bucket URI (e.g., "s3://mybucket") to use instead of the
684
+ default OpenNeuro bucket when downloading data files.
757
685
  records : list[dict] | None
758
- Optional list of pre-fetched metadata records. If provided, the dataset is
759
- constructed directly from these records without querying MongoDB.
760
- offline_mode : bool
761
- If True, do not attempt to query MongoDB at all. This is useful if you want to
762
- work with a local cache only, or if you are offline.
686
+ Pre-fetched metadata records. If provided, the dataset is constructed
687
+ directly from these records and no MongoDB query is performed.
688
+ download : bool, default True
689
+ If False, load from local BIDS files only. Local data are expected
690
+ under ``cache_dir / dataset``; no DB or S3 access is attempted.
763
691
  n_jobs : int
764
- The number of jobs to run in parallel (default is -1, meaning using all processors).
765
- kwargs : dict
766
- Additional keyword arguments to be passed to the EEGDashBaseDataset
767
- constructor.
692
+ Number of parallel jobs to use where applicable (-1 uses all cores).
693
+ eeg_dash_instance : EEGDash | None
694
+ Optional existing EEGDash client to reuse for DB queries. If None,
695
+ a new client is created on demand, not used in the case of no download.
696
+ **kwargs : dict
697
+ Additional keyword arguments serving two purposes:
698
+ - Filtering: any keys present in ``ALLOWED_QUERY_FIELDS`` are treated
699
+ as query filters (e.g., ``dataset``, ``subject``, ``task``, ...).
700
+ - Dataset options: remaining keys are forwarded to the
701
+ ``EEGDashBaseDataset`` constructor.
768
702
 
769
703
  """
704
+ # Parameters that don't need validation
705
+ _suppress_comp_warning: bool = kwargs.pop("_suppress_comp_warning", False)
706
+ self.s3_bucket = s3_bucket
707
+ self.records = records
708
+ self.download = download
709
+ self.n_jobs = n_jobs
710
+ self.eeg_dash_instance = eeg_dash_instance or EEGDash()
711
+
770
712
  self.cache_dir = Path(cache_dir or platformdirs.user_cache_dir("EEGDash"))
713
+
771
714
  if not self.cache_dir.exists():
772
715
  warn(f"Cache directory does not exist, creating it: {self.cache_dir}")
773
716
  self.cache_dir.mkdir(exist_ok=True, parents=True)
774
- self.s3_bucket = s3_bucket
775
- self.eeg_dash = eeg_dash_instance
776
717
 
777
718
  # Separate query kwargs from other kwargs passed to the BaseDataset constructor
778
719
  self.query = query or {}
779
720
  self.query.update(
780
- {k: v for k, v in kwargs.items() if k in EEGDash._ALLOWED_QUERY_FIELDS}
721
+ {k: v for k, v in kwargs.items() if k in ALLOWED_QUERY_FIELDS}
781
722
  )
782
723
  base_dataset_kwargs = {k: v for k, v in kwargs.items() if k not in self.query}
783
724
  if "dataset" not in self.query:
784
- raise ValueError("You must provide a 'dataset' argument")
785
-
786
- self.data_dir = self.cache_dir / self.query["dataset"]
787
- if self.query["dataset"] in RELEASE_TO_OPENNEURO_DATASET_MAP.values():
725
+ # If explicit records are provided, infer dataset from records
726
+ if isinstance(records, list) and records and isinstance(records[0], dict):
727
+ inferred = records[0].get("dataset")
728
+ if inferred:
729
+ self.query["dataset"] = inferred
730
+ else:
731
+ raise ValueError("You must provide a 'dataset' argument")
732
+ else:
733
+ raise ValueError("You must provide a 'dataset' argument")
734
+
735
+ # Decide on a dataset subfolder name for cache isolation. If using
736
+ # challenge/preprocessed buckets (e.g., BDF, mini subsets), append
737
+ # informative suffixes to avoid overlapping with the original dataset.
738
+ dataset_folder = self.query["dataset"]
739
+ if self.s3_bucket:
740
+ suffixes: list[str] = []
741
+ bucket_lower = str(self.s3_bucket).lower()
742
+ if "bdf" in bucket_lower:
743
+ suffixes.append("bdf")
744
+ if "mini" in bucket_lower:
745
+ suffixes.append("mini")
746
+ if suffixes:
747
+ dataset_folder = f"{dataset_folder}-{'-'.join(suffixes)}"
748
+
749
+ self.data_dir = self.cache_dir / dataset_folder
750
+
751
+ if (
752
+ not _suppress_comp_warning
753
+ and self.query["dataset"] in RELEASE_TO_OPENNEURO_DATASET_MAP.values()
754
+ ):
788
755
  warn(
789
756
  "If you are not participating in the competition, you can ignore this warning!"
790
757
  "\n\n"
@@ -800,60 +767,167 @@ class EEGDashDataset(BaseConcatDataset):
800
767
  UserWarning,
801
768
  module="eegdash",
802
769
  )
803
- _owns_client = False
804
- if self.eeg_dash is None and records is None:
805
- self.eeg_dash = EEGDash()
806
- _owns_client = True
807
-
808
- try:
809
- if records is not None:
810
- self.records = records
811
- datasets = [
812
- EEGDashBaseDataset(
813
- record,
814
- self.cache_dir,
815
- self.s3_bucket,
816
- **base_dataset_kwargs,
817
- )
818
- for record in self.records
819
- ]
820
- elif offline_mode: # only assume local data is complete if in offline mode
821
- if self.data_dir.exists():
822
- # This path loads from a local directory and is not affected by DB query logic
823
- datasets = self.load_bids_daxtaset(
824
- dataset=self.query["dataset"],
825
- data_dir=self.data_dir,
826
- description_fields=description_fields,
827
- s3_bucket=s3_bucket,
828
- n_jobs=n_jobs,
829
- **base_dataset_kwargs,
830
- )
831
- else:
832
- raise ValueError(
833
- f"Offline mode is enabled, but local data_dir {self.data_dir} does not exist."
834
- )
835
- elif self.query:
836
- # This is the DB query path that we are improving
837
- datasets = self._find_datasets(
838
- query=self.eeg_dash._build_query_from_kwargs(**self.query),
839
- description_fields=description_fields,
840
- base_dataset_kwargs=base_dataset_kwargs,
841
- )
842
- # We only need filesystem if we need to access S3
843
- self.filesystem = S3FileSystem(
844
- anon=True, client_kwargs={"region_name": "us-east-2"}
770
+ if records is not None:
771
+ self.records = records
772
+ datasets = [
773
+ EEGDashBaseDataset(
774
+ record,
775
+ self.cache_dir,
776
+ self.s3_bucket,
777
+ **base_dataset_kwargs,
845
778
  )
846
- else:
779
+ for record in self.records
780
+ ]
781
+ elif not download: # only assume local data is complete if not downloading
782
+ if not self.data_dir.exists():
847
783
  raise ValueError(
848
- "You must provide either 'records', a 'data_dir', or a query/keyword arguments for filtering."
784
+ f"Offline mode is enabled, but local data_dir {self.data_dir} does not exist."
785
+ )
786
+ records = self._find_local_bids_records(self.data_dir, self.query)
787
+ datasets = [
788
+ EEGDashBaseDataset(
789
+ record=record,
790
+ cache_dir=self.cache_dir,
791
+ s3_bucket=self.s3_bucket,
792
+ description={
793
+ k: record.get(k)
794
+ for k in description_fields
795
+ if record.get(k) is not None
796
+ },
797
+ **base_dataset_kwargs,
849
798
  )
850
- finally:
851
- if _owns_client and self.eeg_dash is not None:
852
- self.eeg_dash.close()
799
+ for record in records
800
+ ]
801
+ elif self.query:
802
+ # This is the DB query path that we are improving
803
+ datasets = self._find_datasets(
804
+ query=build_query_from_kwargs(**self.query),
805
+ description_fields=description_fields,
806
+ base_dataset_kwargs=base_dataset_kwargs,
807
+ )
808
+ # We only need filesystem if we need to access S3
809
+ self.filesystem = S3FileSystem(
810
+ anon=True, client_kwargs={"region_name": "us-east-2"}
811
+ )
812
+ else:
813
+ raise ValueError(
814
+ "You must provide either 'records', a 'data_dir', or a query/keyword arguments for filtering."
815
+ )
853
816
 
854
817
  super().__init__(datasets)
855
818
 
856
- def find_key_in_nested_dict(self, data: Any, target_key: str) -> Any:
819
+ def _find_local_bids_records(
820
+ self, dataset_root: Path, filters: dict[str, Any]
821
+ ) -> list[dict]:
822
+ """Discover local BIDS EEG files and build minimal records.
823
+
824
+ This helper enumerates EEG recordings under ``dataset_root`` via
825
+ ``mne_bids.find_matching_paths`` and applies entity filters to produce a
826
+ list of records suitable for ``EEGDashBaseDataset``. No network access
827
+ is performed and files are not read.
828
+
829
+ Parameters
830
+ ----------
831
+ dataset_root : Path
832
+ Local dataset directory. May be the plain dataset folder (e.g.,
833
+ ``ds005509``) or a suffixed cache variant (e.g.,
834
+ ``ds005509-bdf-mini``).
835
+ filters : dict of {str, Any}
836
+ Query filters. Must include ``'dataset'`` with the dataset id (without
837
+ local suffixes). May include BIDS entities ``'subject'``,
838
+ ``'session'``, ``'task'``, and ``'run'``. Each value can be a scalar
839
+ or a sequence of scalars.
840
+
841
+ Returns
842
+ -------
843
+ records : list of dict
844
+ One record per matched EEG file with at least:
845
+
846
+ - ``'data_name'``
847
+ - ``'dataset'`` (dataset id, without suffixes)
848
+ - ``'bidspath'`` (normalized to start with the dataset id)
849
+ - ``'subject'``, ``'session'``, ``'task'``, ``'run'`` (may be None)
850
+ - ``'bidsdependencies'`` (empty list)
851
+ - ``'modality'`` (``"eeg"``)
852
+ - ``'sampling_frequency'``, ``'nchans'``, ``'ntimes'`` (minimal
853
+ defaults for offline usage)
854
+
855
+ Notes
856
+ -----
857
+ - Matching uses ``datatypes=['eeg']`` and ``suffixes=['eeg']``.
858
+ - ``bidspath`` is constructed as
859
+ ``<dataset_id> / <relative_path_from_dataset_root>`` to ensure the
860
+ first path component is the dataset id (without local cache suffixes).
861
+ - Minimal defaults are set for ``sampling_frequency``, ``nchans``, and
862
+ ``ntimes`` to satisfy dataset length requirements offline.
863
+
864
+ """
865
+ dataset_id = filters["dataset"]
866
+ arg_map = {
867
+ "subjects": "subject",
868
+ "sessions": "session",
869
+ "tasks": "task",
870
+ "runs": "run",
871
+ }
872
+ matching_args: dict[str, list[str]] = {}
873
+ for finder_key, entity_key in arg_map.items():
874
+ entity_val = filters.get(entity_key)
875
+ if entity_val is None:
876
+ continue
877
+ if isinstance(entity_val, (list, tuple, set)):
878
+ entity_vals = list(entity_val)
879
+ if not entity_vals:
880
+ continue
881
+ matching_args[finder_key] = entity_vals
882
+ else:
883
+ matching_args[finder_key] = [entity_val]
884
+
885
+ paths = find_matching_paths(
886
+ root=str(dataset_root),
887
+ datatypes=["eeg"],
888
+ suffixes=["eeg"],
889
+ ignore_json=True,
890
+ **matching_args,
891
+ )
892
+
893
+ records: list[dict] = []
894
+ seen_files: set[str] = set()
895
+
896
+ for bids_path in paths:
897
+ fpath = str(Path(bids_path.fpath).resolve())
898
+ if fpath in seen_files:
899
+ continue
900
+ seen_files.add(fpath)
901
+
902
+ # Build bidspath as dataset_id / relative_path_from_dataset_root (POSIX)
903
+ rel_from_root = (
904
+ Path(bids_path.fpath)
905
+ .resolve()
906
+ .relative_to(Path(bids_path.root).resolve())
907
+ )
908
+ bidspath = f"{dataset_id}/{rel_from_root.as_posix()}"
909
+
910
+ rec = {
911
+ "data_name": f"{dataset_id}_{Path(bids_path.fpath).name}",
912
+ "dataset": dataset_id,
913
+ "bidspath": bidspath,
914
+ "subject": (bids_path.subject or None),
915
+ "session": (bids_path.session or None),
916
+ "task": (bids_path.task or None),
917
+ "run": (bids_path.run or None),
918
+ # minimal fields to satisfy BaseDataset
919
+ "bidsdependencies": [], # not needed to just run.
920
+ "modality": "eeg",
921
+ # this information is from eegdash schema but not available locally
922
+ "sampling_frequency": 1.0,
923
+ "nchans": 1,
924
+ "ntimes": 1,
925
+ }
926
+ records.append(rec)
927
+
928
+ return records
929
+
930
+ def _find_key_in_nested_dict(self, data: Any, target_key: str) -> Any:
857
931
  """Helper to recursively search for a key in a nested dictionary structure; returns
858
932
  the value associated with the first occurrence of the key, or None if not found.
859
933
  """
@@ -861,7 +935,7 @@ class EEGDashDataset(BaseConcatDataset):
861
935
  if target_key in data:
862
936
  return data[target_key]
863
937
  for value in data.values():
864
- result = self.find_key_in_nested_dict(value, target_key)
938
+ result = self._find_key_in_nested_dict(value, target_key)
865
939
  if result is not None:
866
940
  return result
867
941
  return None
@@ -892,13 +966,12 @@ class EEGDashDataset(BaseConcatDataset):
892
966
 
893
967
  """
894
968
  datasets: list[EEGDashBaseDataset] = []
895
-
896
- self.records = self.eeg_dash.find(query)
969
+ self.records = self.eeg_dash_instance.find(query)
897
970
 
898
971
  for record in self.records:
899
972
  description = {}
900
973
  for field in description_fields:
901
- value = self.find_key_in_nested_dict(record, field)
974
+ value = self._find_key_in_nested_dict(record, field)
902
975
  if value is not None:
903
976
  description[field] = value
904
977
  datasets.append(
@@ -911,69 +984,3 @@ class EEGDashDataset(BaseConcatDataset):
911
984
  )
912
985
  )
913
986
  return datasets
914
-
915
- def load_bids_dataset(
916
- self,
917
- dataset: str,
918
- data_dir: str | Path,
919
- description_fields: list[str],
920
- s3_bucket: str | None = None,
921
- n_jobs: int = -1,
922
- **kwargs,
923
- ):
924
- """Helper method to load a single local BIDS dataset and return it as a list of
925
- EEGDashBaseDatasets (one for each recording in the dataset).
926
-
927
- Parameters
928
- ----------
929
- dataset : str
930
- A name for the dataset to be loaded (e.g., "ds002718").
931
- data_dir : str
932
- The path to the local BIDS dataset directory.
933
- description_fields : list[str]
934
- A list of fields to be extracted from the dataset records
935
- and included in the returned dataset description(s).
936
- s3_bucket : str | None
937
- The S3 bucket to upload the dataset files to (if any).
938
- n_jobs : int
939
- The number of jobs to run in parallel (default is -1, meaning using all processors).
940
-
941
- """
942
- bids_dataset = EEGBIDSDataset(
943
- data_dir=data_dir,
944
- dataset=dataset,
945
- )
946
- datasets = Parallel(n_jobs=n_jobs, prefer="threads", verbose=1)(
947
- delayed(self.get_base_dataset_from_bids_file)(
948
- bids_dataset=bids_dataset,
949
- bids_file=bids_file,
950
- s3_bucket=s3_bucket,
951
- description_fields=description_fields,
952
- **kwargs,
953
- )
954
- for bids_file in bids_dataset.get_files()
955
- )
956
- return datasets
957
-
958
- def get_base_dataset_from_bids_file(
959
- self,
960
- bids_dataset: "EEGBIDSDataset",
961
- bids_file: str,
962
- s3_bucket: str | None,
963
- description_fields: list[str],
964
- **kwargs,
965
- ) -> "EEGDashBaseDataset":
966
- """Instantiate a single EEGDashBaseDataset given a local BIDS file (metadata only)."""
967
- record = self.eeg_dash.load_eeg_attrs_from_bids_file(bids_dataset, bids_file)
968
- description = {}
969
- for field in description_fields:
970
- value = self.find_key_in_nested_dict(record, field)
971
- if value is not None:
972
- description[field] = value
973
- return EEGDashBaseDataset(
974
- record,
975
- self.cache_dir,
976
- s3_bucket,
977
- description=description,
978
- **kwargs,
979
- )