eegdash 0.3.7.dev177024734__py3-none-any.whl → 0.3.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of eegdash might be problematic.

eegdash/api.py CHANGED
@@ -3,69 +3,68 @@ import os
 import tempfile
 from pathlib import Path
 from typing import Any, Mapping
+from urllib.parse import urlsplit
 
 import mne
 import numpy as np
-import platformdirs
 import xarray as xr
+from docstring_inheritance import NumpyDocstringInheritanceInitMeta
 from dotenv import load_dotenv
 from joblib import Parallel, delayed
 from mne.utils import warn
-from mne_bids import get_bids_path_from_fname, read_raw_bids
+from mne_bids import find_matching_paths, get_bids_path_from_fname, read_raw_bids
 from pymongo import InsertOne, UpdateOne
 from s3fs import S3FileSystem
 
 from braindecode.datasets import BaseConcatDataset
 
-from .const import RELEASE_TO_OPENNEURO_DATASET_MAP
-from .data_config import config as data_config
-from .data_utils import EEGBIDSDataset, EEGDashBaseDataset
+from .bids_eeg_metadata import (
+    build_query_from_kwargs,
+    load_eeg_attrs_from_bids_file,
+    merge_participants_fields,
+    normalize_key,
+)
+from .const import (
+    ALLOWED_QUERY_FIELDS,
+    RELEASE_TO_OPENNEURO_DATASET_MAP,
+)
+from .const import config as data_config
+from .data_utils import (
+    EEGBIDSDataset,
+    EEGDashBaseDataset,
+)
 from .mongodb import MongoConnectionManager
+from .paths import get_default_cache_dir
 
 logger = logging.getLogger("eegdash")
 
 
 class EEGDash:
-    """A high-level interface to the EEGDash database.
+    """High-level interface to the EEGDash metadata database.
 
-    This class is primarily used to interact with the metadata records stored in the
-    EEGDash database (or a private instance of it), allowing users to find, add, and
-    update EEG data records.
-
-    While this class provides basic support for loading EEG data, please see
-    the EEGDashDataset class for a more complete way to retrieve and work with full
-    datasets.
+    Provides methods to query, insert, and update metadata records stored in the
+    EEGDash MongoDB database (public or private). Also includes utilities to load
+    EEG data from S3 for matched records.
 
+    For working with collections of recordings as PyTorch datasets, prefer
+    :class:`EEGDashDataset`.
     """
 
-    _ALLOWED_QUERY_FIELDS = {
-        "data_name",
-        "dataset",
-        "subject",
-        "task",
-        "session",
-        "run",
-        "modality",
-        "sampling_frequency",
-        "nchans",
-        "ntimes",
-    }
-
     def __init__(self, *, is_public: bool = True, is_staging: bool = False) -> None:
-        """Create new instance of the EEGDash Database client.
+        """Create a new EEGDash client.
 
         Parameters
         ----------
-        is_public: bool
-            Whether to connect to the public MongoDB database; if False, connect to a
-            private database instance as per the DB_CONNECTION_STRING env variable
-            (or .env file entry).
-        is_staging: bool
-            If True, use staging MongoDB database ("eegdashstaging"); otherwise use the
-            production database ("eegdash").
-
-        Example
-        -------
+        is_public : bool, default True
+            Connect to the public MongoDB database. If ``False``, connect to a
+            private database instance using the ``DB_CONNECTION_STRING`` environment
+            variable (or value from a ``.env`` file).
+        is_staging : bool, default False
+            If ``True``, use the staging database (``eegdashstaging``); otherwise
+            use the production database (``eegdash``).
+
+        Examples
+        --------
         >>> eegdash = EEGDash()
 
        """
@@ -106,23 +105,25 @@ class EEGDash:
 
         Parameters
         ----------
-        query: dict, optional
-            A complete MongoDB query dictionary. This is a positional-only argument.
-        **kwargs:
-            Keyword arguments representing field-value pairs for the query.
-            Values can be single items (str, int) or lists of items for multi-search.
+        query : dict, optional
+            Complete MongoDB query dictionary. This is a positional-only
+            argument.
+        **kwargs
+            User-friendly field filters that are converted to a MongoDB query.
+            Values can be scalars (e.g., ``"sub-01"``) or sequences (translated
+            to ``$in`` queries).
 
         Returns
         -------
-        list:
-            A list of DB records (string-keyed dictionaries) that match the query.
+        list of dict
+            DB records that match the query.
 
         """
         final_query: dict[str, Any] | None = None
 
         # Accept explicit empty dict {} to mean "match all"
         raw_query = query if isinstance(query, dict) else None
-        kwargs_query = self._build_query_from_kwargs(**kwargs) if kwargs else None
+        kwargs_query = build_query_from_kwargs(**kwargs) if kwargs else None
 
         # Determine presence, treating {} as a valid raw query
         has_raw = isinstance(raw_query, dict)
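Per the docstring above, keyword values may be scalars or sequences, and sequences are translated to $in. A sketch of the equivalence (dataset and subject IDs are taken from the examples elsewhere in this file):

    client = EEGDash()
    subjects = ["NDARCA153NKE", "NDARXT792GY8"]

    # These two calls express the same query:
    recs = client.find(dataset="ds005505", subject=subjects)
    recs = client.find({"dataset": "ds005505", "subject": {"$in": subjects}})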
@@ -239,59 +240,12 @@ class EEGDash:
         return record
 
     def _build_query_from_kwargs(self, **kwargs) -> dict[str, Any]:
-        """Build and validate a MongoDB query from user-friendly keyword arguments.
+        """Internal helper to build a validated MongoDB query from keyword args.
 
-        Improvements:
-        - Reject None values and empty/whitespace-only strings
-        - For list/tuple/set values: strip strings, drop None/empties, deduplicate, and use `$in`
-        - Preserve scalars as exact matches
+        This delegates to the module-level builder used across the package and
+        is exposed here for testing and convenience.
         """
-        # 1. Validate that all provided keys are allowed for querying
-        unknown_fields = set(kwargs.keys()) - self._ALLOWED_QUERY_FIELDS
-        if unknown_fields:
-            raise ValueError(
-                f"Unsupported query field(s): {', '.join(sorted(unknown_fields))}. "
-                f"Allowed fields are: {', '.join(sorted(self._ALLOWED_QUERY_FIELDS))}"
-            )
-
-        # 2. Construct the query dictionary
-        query = {}
-        for key, value in kwargs.items():
-            # None is not a valid constraint
-            if value is None:
-                raise ValueError(
-                    f"Received None for query parameter '{key}'. Provide a concrete value."
-                )
-
-            # Handle list-like values as multi-constraints
-            if isinstance(value, (list, tuple, set)):
-                cleaned: list[Any] = []
-                for item in value:
-                    if item is None:
-                        continue
-                    if isinstance(item, str):
-                        item = item.strip()
-                        if not item:
-                            continue
-                    cleaned.append(item)
-                # Deduplicate while preserving order
-                cleaned = list(dict.fromkeys(cleaned))
-                if not cleaned:
-                    raise ValueError(
-                        f"Received an empty list for query parameter '{key}'. This is not supported."
-                    )
-                query[key] = {"$in": cleaned}
-            else:
-                # Scalars: trim strings and validate
-                if isinstance(value, str):
-                    value = value.strip()
-                    if not value:
-                        raise ValueError(
-                            f"Received an empty string for query parameter '{key}'."
-                        )
-                query[key] = value
-
-        return query
+        return build_query_from_kwargs(**kwargs)
 
     # --- Query merging and conflict detection helpers ---
     def _extract_simple_constraint(self, query: dict[str, Any], key: str):
@@ -324,8 +278,8 @@ class EEGDash:
             return
 
         # Only consider fields we generally allow; skip meta operators like $and
-        raw_keys = set(raw_query.keys()) & self._ALLOWED_QUERY_FIELDS
-        kw_keys = set(kwargs_query.keys()) & self._ALLOWED_QUERY_FIELDS
+        raw_keys = set(raw_query.keys()) & ALLOWED_QUERY_FIELDS
+        kw_keys = set(kwargs_query.keys()) & ALLOWED_QUERY_FIELDS
         dup_keys = raw_keys & kw_keys
         for key in dup_keys:
             rc = self._extract_simple_constraint(raw_query, key)
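These helpers let a raw query and keyword filters be combined (rather than having the kwargs ignored, as in the previous release); based on the conflict detection shown, constraining the same field two different ways should be rejected instead of silently merged. A hedged sketch:

    client = EEGDash()

    # Raw query and keyword filters are ANDed together:
    recs = client.find({"dataset": "ds005505"}, task="RestingState")

    # Expected to be rejected: both sides pin 'task' to different values.
    # client.find({"task": "RestingState"}, task="Video")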
@@ -360,44 +314,95 @@ class EEGDash:
             )
 
     def load_eeg_data_from_s3(self, s3path: str) -> xr.DataArray:
-        """Load an EEGLAB .set file from an AWS S3 URI and return it as an xarray DataArray.
+        """Load EEG data from an S3 URI into an ``xarray.DataArray``.
+
+        Preserves the original filename, downloads sidecar files when applicable
+        (e.g., ``.fdt`` for EEGLAB, ``.vmrk``/``.eeg`` for BrainVision), and uses
+        MNE's direct readers.
 
         Parameters
         ----------
         s3path : str
-            An S3 URI (should start with "s3://") for the file in question.
+            An S3 URI (should start with "s3://").
 
         Returns
         -------
         xr.DataArray
-            A DataArray containing the EEG data, with dimensions "channel" and "time".
+            EEG data with dimensions ``("channel", "time")``.
 
-        Example
-        -------
-        >>> eegdash = EEGDash()
-        >>> mypath = "s3://openneuro.org/path/to/your/eeg_data.set"
-        >>> mydata = eegdash.load_eeg_data_from_s3(mypath)
+        Raises
+        ------
+        ValueError
+            If the file extension is unsupported.
 
         """
-        with tempfile.NamedTemporaryFile(delete=False, suffix=".set") as tmp:
-            with self.filesystem.open(s3path) as s3_file:
-                tmp.write(s3_file.read())
-            tmp_path = tmp.name
-        eeg_data = self.load_eeg_data_from_bids_file(tmp_path)
-        os.unlink(tmp_path)
-        return eeg_data
+        # choose a temp dir so sidecars can be colocated
+        with tempfile.TemporaryDirectory() as tmpdir:
+            # Derive local filenames from the S3 key to keep base name consistent
+            s3_key = urlsplit(s3path).path  # e.g., "/dsXXXX/sub-.../..._eeg.set"
+            basename = Path(s3_key).name
+            ext = Path(basename).suffix.lower()
+            local_main = Path(tmpdir) / basename
+
+            # Download main file
+            with (
+                self.filesystem.open(s3path, mode="rb") as fsrc,
+                open(local_main, "wb") as fdst,
+            ):
+                fdst.write(fsrc.read())
+
+            # Determine and fetch any required sidecars
+            sidecars: list[str] = []
+            if ext == ".set":  # EEGLAB
+                sidecars = [".fdt"]
+            elif ext == ".vhdr":  # BrainVision
+                sidecars = [".vmrk", ".eeg", ".dat", ".raw"]
+
+            for sc_ext in sidecars:
+                sc_key = s3_key[: -len(ext)] + sc_ext
+                sc_uri = f"s3://{urlsplit(s3path).netloc}{sc_key}"
+                try:
+                    # If sidecar exists, download next to the main file
+                    info = self.filesystem.info(sc_uri)
+                    if info:
+                        sc_local = Path(tmpdir) / Path(sc_key).name
+                        with (
+                            self.filesystem.open(sc_uri, mode="rb") as fsrc,
+                            open(sc_local, "wb") as fdst,
+                        ):
+                            fdst.write(fsrc.read())
+                except Exception:
+                    # Sidecar not present; skip silently
+                    pass
+
+            # Read using appropriate MNE reader
+            raw = mne.io.read_raw(str(local_main), preload=True, verbose=False)
+
+            data = raw.get_data()
+            fs = raw.info["sfreq"]
+            max_time = data.shape[1] / fs
+            time_steps = np.linspace(0, max_time, data.shape[1]).squeeze()
+            channel_names = raw.ch_names
+
+            return xr.DataArray(
+                data=data,
+                dims=["channel", "time"],
+                coords={"time": time_steps, "channel": channel_names},
+            )
 
     def load_eeg_data_from_bids_file(self, bids_file: str) -> xr.DataArray:
-        """Load EEG data from a local file and return it as a xarray DataArray.
+        """Load EEG data from a local BIDS-formatted file.
 
         Parameters
         ----------
         bids_file : str
-            Path to the BIDS-compliant file on the local filesystem.
+            Path to a BIDS-compliant EEG file (e.g., ``*_eeg.edf``, ``*_eeg.bdf``,
+            ``*_eeg.vhdr``, ``*_eeg.set``).
 
-        Notes
-        -----
-        Currently, only non-epoched .set files are supported.
+        Returns
+        -------
+        xr.DataArray
+            EEG data with dimensions ``("channel", "time")``.
 
         """
         bids_path = get_bids_path_from_fname(bids_file, verbose=False)
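The new sidecar handling only rewrites the extension of the S3 key. A standalone sketch of that derivation (the URI is a placeholder):

    from pathlib import Path
    from urllib.parse import urlsplit

    s3path = "s3://openneuro.org/ds005505/sub-01/eeg/sub-01_task-rest_eeg.set"  # placeholder
    s3_key = urlsplit(s3path).path           # "/ds005505/sub-01/eeg/sub-01_task-rest_eeg.set"
    ext = Path(s3_key).suffix.lower()        # ".set"
    fdt_key = s3_key[: -len(ext)] + ".fdt"   # EEGLAB sidecar
    fdt_uri = f"s3://{urlsplit(s3path).netloc}{fdt_key}"
    # -> "s3://openneuro.org/ds005505/sub-01/eeg/sub-01_task-rest_eeg.fdt"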
@@ -417,140 +422,25 @@ class EEGDash:
         )
         return eeg_xarray
 
-    def get_raw_extensions(
-        self, bids_file: str, bids_dataset: EEGBIDSDataset
-    ) -> list[str]:
-        """Helper to find paths to additional "sidecar" files that may be associated
-        with a given main data file in a BIDS dataset; paths are returned as relative to
-        the parent dataset path.
-
-        For example, if the input file is a .set file, this will return the relative path
-        to a corresponding .fdt file (if any).
-        """
-        bids_file = Path(bids_file)
-        extensions = {
-            ".set": [".set", ".fdt"],  # eeglab
-            ".edf": [".edf"],  # european
-            ".vhdr": [".eeg", ".vhdr", ".vmrk", ".dat", ".raw"],  # brainvision
-            ".bdf": [".bdf"],  # biosemi
-        }
-        return [
-            str(bids_dataset.get_relative_bidspath(bids_file.with_suffix(suffix)))
-            for suffix in extensions[bids_file.suffix]
-            if bids_file.with_suffix(suffix).exists()
-        ]
-
-    def load_eeg_attrs_from_bids_file(
-        self, bids_dataset: EEGBIDSDataset, bids_file: str
-    ) -> dict[str, Any]:
-        """Build the metadata record for a given BIDS file (single recording) in a BIDS dataset.
-
-        Attributes are at least the ones defined in data_config attributes (set to None if missing),
-        but are typically a superset, and include, among others, the paths to relevant
-        meta-data files needed to load and interpret the file in question.
-
-        Parameters
-        ----------
-        bids_dataset : EEGBIDSDataset
-            The BIDS dataset object containing the file.
-        bids_file : str
-            The path to the BIDS file within the dataset.
-
-        Returns
-        -------
-        dict:
-            A dictionary representing the metadata record for the given file. This is the
-            same format as the records stored in the database.
-
-        """
-        if bids_file not in bids_dataset.files:
-            raise ValueError(f"{bids_file} not in {bids_dataset.dataset}")
-
-        # Initialize attrs with None values for all expected fields
-        attrs = {field: None for field in self.config["attributes"].keys()}
-
-        file = Path(bids_file).name
-        dsnumber = bids_dataset.dataset
-        # extract openneuro path by finding the first occurrence of the dataset name in the filename and remove the path before that
-        openneuro_path = dsnumber + bids_file.split(dsnumber)[1]
-
-        # Update with actual values where available
-        try:
-            participants_tsv = bids_dataset.subject_participant_tsv(bids_file)
-        except Exception as e:
-            logger.error("Error getting participants_tsv: %s", str(e))
-            participants_tsv = None
-
-        try:
-            eeg_json = bids_dataset.eeg_json(bids_file)
-        except Exception as e:
-            logger.error("Error getting eeg_json: %s", str(e))
-            eeg_json = None
-
-        bids_dependencies_files = self.config["bids_dependencies_files"]
-        bidsdependencies = []
-        for extension in bids_dependencies_files:
-            try:
-                dep_path = bids_dataset.get_bids_metadata_files(bids_file, extension)
-                dep_path = [
-                    str(bids_dataset.get_relative_bidspath(dep)) for dep in dep_path
-                ]
-                bidsdependencies.extend(dep_path)
-            except Exception:
-                pass
-
-        bidsdependencies.extend(self.get_raw_extensions(bids_file, bids_dataset))
-
-        # Define field extraction functions with error handling
-        field_extractors = {
-            "data_name": lambda: f"{bids_dataset.dataset}_{file}",
-            "dataset": lambda: bids_dataset.dataset,
-            "bidspath": lambda: openneuro_path,
-            "subject": lambda: bids_dataset.get_bids_file_attribute(
-                "subject", bids_file
-            ),
-            "task": lambda: bids_dataset.get_bids_file_attribute("task", bids_file),
-            "session": lambda: bids_dataset.get_bids_file_attribute(
-                "session", bids_file
-            ),
-            "run": lambda: bids_dataset.get_bids_file_attribute("run", bids_file),
-            "modality": lambda: bids_dataset.get_bids_file_attribute(
-                "modality", bids_file
-            ),
-            "sampling_frequency": lambda: bids_dataset.get_bids_file_attribute(
-                "sfreq", bids_file
-            ),
-            "nchans": lambda: bids_dataset.get_bids_file_attribute("nchans", bids_file),
-            "ntimes": lambda: bids_dataset.get_bids_file_attribute("ntimes", bids_file),
-            "participant_tsv": lambda: participants_tsv,
-            "eeg_json": lambda: eeg_json,
-            "bidsdependencies": lambda: bidsdependencies,
-        }
-
-        # Dynamically populate attrs with error handling
-        for field, extractor in field_extractors.items():
-            try:
-                attrs[field] = extractor()
-            except Exception as e:
-                logger.error("Error extracting %s : %s", field, str(e))
-                attrs[field] = None
-
-        return attrs
-
     def add_bids_dataset(
         self, dataset: str, data_dir: str, overwrite: bool = True
     ) -> None:
-        """Traverse the BIDS dataset at data_dir and add its records to the MongoDB database,
-        under the given dataset name.
+        """Scan a local BIDS dataset and upsert records into MongoDB.
 
         Parameters
         ----------
-        dataset : str)
-            The name of the dataset to be added (e.g., "ds002718").
+        dataset : str
+            Dataset identifier (e.g., ``"ds002718"``).
         data_dir : str
-            The path to the BIDS dataset directory.
-        overwrite : bool
-            Whether to overwrite/update existing records in the database.
+            Path to the local BIDS dataset directory.
+        overwrite : bool, default True
+            If ``True``, update existing records when encountered; otherwise,
+            skip records that already exist.
+
+        Raises
+        ------
+        ValueError
+            If called on a public client (``is_public=True``).
 
         """
         if self.is_public:
@@ -565,7 +455,7 @@ class EEGDash:
                 dataset=dataset,
             )
         except Exception as e:
-            logger.error("Error creating bids dataset %s: $s", dataset, str(e))
+            logger.error("Error creating bids dataset %s: %s", dataset, str(e))
             raise e
         requests = []
         for bids_file in bids_dataset.get_files():
@@ -574,15 +464,13 @@ class EEGDash:
 
                 if self.exist({"data_name": data_id}):
                     if overwrite:
-                        eeg_attrs = self.load_eeg_attrs_from_bids_file(
+                        eeg_attrs = load_eeg_attrs_from_bids_file(
                             bids_dataset, bids_file
                         )
-                        requests.append(self.update_request(eeg_attrs))
+                        requests.append(self._update_request(eeg_attrs))
                 else:
-                    eeg_attrs = self.load_eeg_attrs_from_bids_file(
-                        bids_dataset, bids_file
-                    )
-                    requests.append(self.add_request(eeg_attrs))
+                    eeg_attrs = load_eeg_attrs_from_bids_file(bids_dataset, bids_file)
+                    requests.append(self._add_request(eeg_attrs))
             except Exception as e:
                 logger.error("Error adding record %s", bids_file)
                 logger.error(str(e))
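Ingestion is rejected on the public client, so add_bids_dataset requires a private instance; a hedged sketch (path and URI are placeholders):

    import os

    os.environ["DB_CONNECTION_STRING"] = "mongodb://localhost:27017"  # placeholder
    client = EEGDash(is_public=False)
    client.add_bids_dataset(dataset="ds002718", data_dir="/data/ds002718", overwrite=True)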
@@ -598,22 +486,22 @@ class EEGDash:
         logger.info("Errors: %s ", result.bulk_api_result.get("writeErrors", []))
 
     def get(self, query: dict[str, Any]) -> list[xr.DataArray]:
-        """Retrieve a list of EEG data arrays that match the given query. See also
-        the `find()` method for details on the query format.
+        """Download and return EEG data arrays for records matching a query.
 
         Parameters
         ----------
         query : dict
-            A dictionary that specifies the query to be executed; this is a reference
-            document that is used to match records in the MongoDB collection.
+            MongoDB query used to select records.
 
         Returns
         -------
-        A list of xarray DataArray objects containing the EEG data for each matching record.
+        list of xr.DataArray
+            EEG data for each matching record, with dimensions ``("channel", "time")``.
 
         Notes
        -----
-        Retrieval is done in parallel, and the downloaded data are not cached locally.
+        Retrieval runs in parallel. Downloaded files are read and discarded
+        (no on-disk caching here).
 
         """
         sessions = self.find(query)
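A short usage sketch of get (query values are placeholders; each returned array carries ("channel", "time") dims and nothing is cached on disk):

    client = EEGDash()
    arrays = client.get({"dataset": "ds005505", "subject": "NDARCA153NKE"})
    for arr in arrays:
        print(arr.sizes["channel"], arr.sizes["time"])

    # Internally, each record's relative 'bidspath' is resolved against the
    # OpenNeuro bucket (see _get_s3path in the next hunk), e.g.
    # "ds005505/sub-X/eeg/..._eeg.set" -> "s3://openneuro.org/ds005505/sub-X/eeg/..._eeg.set"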
@@ -623,12 +511,40 @@ class EEGDash:
         results = Parallel(
             n_jobs=-1 if len(sessions) > 1 else 1, prefer="threads", verbose=1
         )(
-            delayed(self.load_eeg_data_from_s3)(self.get_s3path(session))
+            delayed(self.load_eeg_data_from_s3)(self._get_s3path(session))
             for session in sessions
         )
         return results
 
-    def add_request(self, record: dict):
+    def _get_s3path(self, record: Mapping[str, Any] | str) -> str:
+        """Build an S3 URI from a DB record or a relative path.
+
+        Parameters
+        ----------
+        record : dict or str
+            Either a DB record containing a ``'bidspath'`` key, or a relative
+            path string under the OpenNeuro bucket.
+
+        Returns
+        -------
+        str
+            Fully qualified S3 URI.
+
+        Raises
+        ------
+        ValueError
+            If a mapping is provided but ``'bidspath'`` is missing.
+
+        """
+        if isinstance(record, str):
+            rel = record
+        else:
+            rel = record.get("bidspath")
+        if not rel:
+            raise ValueError("Record missing 'bidspath' for S3 path resolution")
+        return f"s3://openneuro.org/{rel}"
+
+    def _add_request(self, record: dict):
         """Internal helper method to create a MongoDB insertion request for a record."""
         return InsertOne(record)
 
@@ -642,12 +558,19 @@ class EEGDash:
         except:
             logger.error("Error adding record: %s ", record["data_name"])
 
-    def update_request(self, record: dict):
+    def _update_request(self, record: dict):
         """Internal helper method to create a MongoDB update request for a record."""
         return UpdateOne({"data_name": record["data_name"]}, {"$set": record})
 
     def update(self, record: dict):
-        """Update a single record in the MongoDB collection."""
+        """Update a single record in the MongoDB collection.
+
+        Parameters
+        ----------
+        record : dict
+            Record content to set at the matching ``data_name``.
+
+        """
         try:
             self.__collection.update_one(
                 {"data_name": record["data_name"]}, {"$set": record}
@@ -655,15 +578,33 @@ class EEGDash:
         except:  # silent failure
             logger.error("Error updating record: %s", record["data_name"])
 
+    def exists(self, query: dict[str, Any]) -> bool:
+        """Alias for :meth:`exist` provided for API clarity."""
+        return self.exist(query)
+
     def remove_field(self, record, field):
-        """Remove a specific field from a record in the MongoDB collection."""
+        """Remove a specific field from a record in the MongoDB collection.
+
+        Parameters
+        ----------
+        record : dict
+            Record identifying object with ``data_name``.
+        field : str
+            Field name to remove.
+
+        """
         self.__collection.update_one(
             {"data_name": record["data_name"]}, {"$unset": {field: 1}}
         )
 
     def remove_field_from_db(self, field):
-        """Removed all occurrences of a specific field from all records in the MongoDB
-        collection. WARNING: this operation is destructive and should be used with caution.
+        """Remove a field from all records (destructive).
+
+        Parameters
+        ----------
+        field : str
+            Field name to remove from every document.
+
         """
         self.__collection.update_many({}, {"$unset": {field: 1}})
 
@@ -673,11 +614,13 @@ class EEGDash:
         return self.__collection
 
     def close(self):
-        """Close the MongoDB client connection.
+        """Backward-compatibility no-op; connections are managed globally.
+
+        Notes
+        -----
+        Connections are managed by :class:`MongoConnectionManager`. Use
+        :meth:`close_all_connections` to explicitly close all clients.
 
-        Note: Since MongoDB clients are now managed by a singleton,
-        this method no longer closes connections. Use close_all_connections()
-        class method to close all connections if needed.
         """
         # Individual instances no longer close the shared client
         pass
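As the notes say, per-instance close() is now a no-op; releasing the shared clients is done once, explicitly:

    # e.g., at the very end of a script or test session
    EEGDash.close_all_connections()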
@@ -688,12 +631,77 @@ class EEGDash:
         MongoConnectionManager.close_all()
 
     def __del__(self):
-        """Ensure connection is closed when object is deleted."""
+        """Destructor; no explicit action needed due to global connection manager."""
         # No longer needed since we're using singleton pattern
         pass
 
 
-class EEGDashDataset(BaseConcatDataset):
+class EEGDashDataset(BaseConcatDataset, metaclass=NumpyDocstringInheritanceInitMeta):
+    """Create a new EEGDashDataset from a given query or local BIDS dataset directory
+    and dataset name. An EEGDashDataset is a pooled collection of EEGDashBaseDataset
+    instances (individual recordings) and is a subclass of braindecode's BaseConcatDataset.
+
+    Querying Examples
+    -----------------
+    # Find by single subject
+    >>> ds = EEGDashDataset(dataset="ds005505", subject="NDARCA153NKE")
+
+    # Find by a list of subjects and a specific task
+    >>> subjects = ["NDARCA153NKE", "NDARXT792GY8"]
+    >>> ds = EEGDashDataset(dataset="ds005505", subject=subjects, task="RestingState")
+
+    # Use a raw MongoDB query for advanced filtering
+    >>> raw_query = {"dataset": "ds005505", "subject": {"$in": subjects}}
+    >>> ds = EEGDashDataset(query=raw_query)
+
+    Parameters
+    ----------
+    cache_dir : str | Path
+        Directory where data are cached locally. If not specified, a default
+        cache directory under the user cache is used.
+    query : dict | None
+        Raw MongoDB query to filter records. If provided, it is merged with
+        keyword filtering arguments (see ``**kwargs``) using logical AND.
+        You must provide at least a ``dataset`` (either in ``query`` or
+        as a keyword argument). Only fields in ``ALLOWED_QUERY_FIELDS`` are
+        considered for filtering.
+    dataset : str
+        Dataset identifier (e.g., ``"ds002718"``). Required if ``query`` does
+        not already specify a dataset.
+    task : str | list[str]
+        Task name(s) to filter by (e.g., ``"RestingState"``).
+    subject : str | list[str]
+        Subject identifier(s) to filter by (e.g., ``"NDARCA153NKE"``).
+    session : str | list[str]
+        Session identifier(s) to filter by (e.g., ``"1"``).
+    run : str | list[str]
+        Run identifier(s) to filter by (e.g., ``"1"``).
+    description_fields : list[str]
+        Fields to extract from each record and include in dataset descriptions
+        (e.g., "subject", "session", "run", "task").
+    s3_bucket : str | None
+        Optional S3 bucket URI (e.g., "s3://mybucket") to use instead of the
+        default OpenNeuro bucket when downloading data files.
+    records : list[dict] | None
+        Pre-fetched metadata records. If provided, the dataset is constructed
+        directly from these records and no MongoDB query is performed.
+    download : bool, default True
+        If False, load from local BIDS files only. Local data are expected
+        under ``cache_dir / dataset``; no DB or S3 access is attempted.
+    n_jobs : int
+        Number of parallel jobs to use where applicable (-1 uses all cores).
+    eeg_dash_instance : EEGDash | None
+        Optional existing EEGDash client to reuse for DB queries. If None,
+        a new client is created on demand; not used when ``download=False``.
+    **kwargs : dict
+        Additional keyword arguments serving two purposes:
+        - Filtering: any keys present in ``ALLOWED_QUERY_FIELDS`` are treated
+          as query filters (e.g., ``dataset``, ``subject``, ``task``, ...).
+        - Dataset options: remaining keys are forwarded to the
+          ``EEGDashBaseDataset`` constructor.
+
+    """
+
     def __init__(
         self,
         cache_dir: str | Path,
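A hedged sketch of the new download=False path documented above (the cache path is a placeholder and must already contain the dataset folder):

    ds = EEGDashDataset(
        cache_dir="/data/eegdash_cache",  # placeholder; must contain ds005505/
        dataset="ds005505",
        task="RestingState",
        download=False,
    )
    print(len(ds.datasets))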
@@ -708,162 +716,280 @@ class EEGDashDataset(BaseConcatDataset):
             "sex",
         ],
         s3_bucket: str | None = None,
-        eeg_dash_instance=None,
         records: list[dict] | None = None,
-        offline_mode: bool = False,
+        download: bool = True,
         n_jobs: int = -1,
+        eeg_dash_instance: EEGDash | None = None,
         **kwargs,
     ):
-        """Create a new EEGDashDataset from a given query or local BIDS dataset directory
-        and dataset name. An EEGDashDataset is pooled collection of EEGDashBaseDataset
-        instances (individual recordings) and is a subclass of braindecode's BaseConcatDataset.
-
-
-        Querying Examples:
-        ------------------
-        # Find by single subject
-        >>> ds = EEGDashDataset(dataset="ds005505", subject="NDARCA153NKE")
-
-        # Find by a list of subjects and a specific task
-        >>> subjects = ["NDARCA153NKE", "NDARXT792GY8"]
-        >>> ds = EEGDashDataset(dataset="ds005505", subject=subjects, task="RestingState")
-
-        # Use a raw MongoDB query for advanced filtering
-        >>> raw_query = {"dataset": "ds005505", "subject": {"$in": subjects}}
-        >>> ds = EEGDashDataset(query=raw_query)
+        # Parameters that don't need validation
+        _suppress_comp_warning: bool = kwargs.pop("_suppress_comp_warning", False)
+        self.s3_bucket = s3_bucket
+        self.records = records
+        self.download = download
+        self.n_jobs = n_jobs
+        self.eeg_dash_instance = eeg_dash_instance or EEGDash()
 
-        Parameters
-        ----------
-        query : dict | None
-            A raw MongoDB query dictionary. If provided, keyword arguments for filtering are ignored.
-        **kwargs : dict
-            Keyword arguments for filtering (e.g., `subject="X"`, `task=["T1", "T2"]`) and/or
-            arguments to be passed to the EEGDashBaseDataset constructor (e.g., `subject=...`).
-        cache_dir : str
-            A directory where the dataset will be cached locally.
-        data_dir : str | None
-            Optionally a string specifying a local BIDS dataset directory from which to load the EEG data files. Exactly one
-            of query or data_dir must be provided.
-        dataset : str | None
-            If data_dir is given, a name for the dataset to be loaded.
-        description_fields : list[str]
-            A list of fields to be extracted from the dataset records
-            and included in the returned data description(s). Examples are typical
-            subject metadata fields such as "subject", "session", "run", "task", etc.;
-            see also data_config.description_fields for the default set of fields.
-        s3_bucket : str | None
-            An optional S3 bucket URI (e.g., "s3://mybucket") to use instead of the
-            default OpenNeuro bucket for loading data files
-        records : list[dict] | None
-            Optional list of pre-fetched metadata records. If provided, the dataset is
-            constructed directly from these records without querying MongoDB.
-        offline_mode : bool
-            If True, do not attempt to query MongoDB at all. This is useful if you want to
-            work with a local cache only, or if you are offline.
-        n_jobs : int
-            The number of jobs to run in parallel (default is -1, meaning using all processors).
-        kwargs : dict
-            Additional keyword arguments to be passed to the EEGDashBaseDataset
-            constructor.
+        # Resolve a unified cache directory across code/tests/CI
+        self.cache_dir = Path(cache_dir or get_default_cache_dir())
 
-        """
-        self.cache_dir = Path(cache_dir or platformdirs.user_cache_dir("EEGDash"))
         if not self.cache_dir.exists():
             warn(f"Cache directory does not exist, creating it: {self.cache_dir}")
             self.cache_dir.mkdir(exist_ok=True, parents=True)
-        self.s3_bucket = s3_bucket
-        self.eeg_dash = eeg_dash_instance
 
         # Separate query kwargs from other kwargs passed to the BaseDataset constructor
         self.query = query or {}
         self.query.update(
-            {k: v for k, v in kwargs.items() if k in EEGDash._ALLOWED_QUERY_FIELDS}
+            {k: v for k, v in kwargs.items() if k in ALLOWED_QUERY_FIELDS}
        )
        base_dataset_kwargs = {k: v for k, v in kwargs.items() if k not in self.query}
        if "dataset" not in self.query:
-            raise ValueError("You must provide a 'dataset' argument")
-
-        self.data_dir = self.cache_dir / self.query["dataset"]
-        if self.query["dataset"] in RELEASE_TO_OPENNEURO_DATASET_MAP.values():
+            # If explicit records are provided, infer dataset from records
+            if isinstance(records, list) and records and isinstance(records[0], dict):
+                inferred = records[0].get("dataset")
+                if inferred:
+                    self.query["dataset"] = inferred
+                else:
+                    raise ValueError("You must provide a 'dataset' argument")
+            else:
+                raise ValueError("You must provide a 'dataset' argument")
+
+        # Decide on a dataset subfolder name for cache isolation. If using
+        # challenge/preprocessed buckets (e.g., BDF, mini subsets), append
+        # informative suffixes to avoid overlapping with the original dataset.
+        dataset_folder = self.query["dataset"]
+        if self.s3_bucket:
+            suffixes: list[str] = []
+            bucket_lower = str(self.s3_bucket).lower()
+            if "bdf" in bucket_lower:
+                suffixes.append("bdf")
+            if "mini" in bucket_lower:
+                suffixes.append("mini")
+            if suffixes:
+                dataset_folder = f"{dataset_folder}-{'-'.join(suffixes)}"
+
+        self.data_dir = self.cache_dir / dataset_folder
+
+        if (
+            not _suppress_comp_warning
+            and self.query["dataset"] in RELEASE_TO_OPENNEURO_DATASET_MAP.values()
+        ):
             warn(
                 "If you are not participating in the competition, you can ignore this warning!"
                 "\n\n"
-                "[EEGChallengeDataset] EEG 2025 Competition Data Notice:\n"
-                "-------------------------------------------------------\n"
+                "EEG 2025 Competition Data Notice:\n"
+                "---------------------------------\n"
                 " You are loading the dataset that is used in the EEG 2025 Competition:\n"
-                "IMPORTANT: The data accessed via `EEGDashDataset` is NOT identical to what you get from `EEGChallengeDataset` directly.\n"
+                "IMPORTANT: The data accessed via `EEGDashDataset` is NOT identical to what you get from the `EEGChallengeDataset` object directly.\n"
                 "and it is not what you will use for the competition. Downsampling and filtering were applied to the data"
                 "to allow more people to participate.\n"
-                "\n",
+                "\n"
                 "If you are participating in the competition, always use `EEGChallengeDataset` to ensure consistency with the challenge data.\n"
                 "\n",
                 UserWarning,
                 module="eegdash",
             )
-        _owns_client = False
-        if self.eeg_dash is None and records is None:
-            self.eeg_dash = EEGDash()
-            _owns_client = True
+        if records is not None:
+            self.records = records
+            datasets = [
+                EEGDashBaseDataset(
+                    record,
+                    self.cache_dir,
+                    self.s3_bucket,
+                    **base_dataset_kwargs,
+                )
+                for record in self.records
+            ]
+        elif not download:  # only assume local data is complete if not downloading
+            if not self.data_dir.exists():
+                raise ValueError(
+                    f"Offline mode is enabled, but local data_dir {self.data_dir} does not exist."
+                )
+            records = self._find_local_bids_records(self.data_dir, self.query)
+            # Try to enrich from local participants.tsv to restore requested fields
+            try:
+                bids_ds = EEGBIDSDataset(
+                    data_dir=str(self.data_dir), dataset=self.query["dataset"]
+                )  # type: ignore[index]
+            except Exception:
+                bids_ds = None
+
+            datasets = []
+            for record in records:
+                # Start with entity values from filename
+                desc: dict[str, Any] = {
+                    k: record.get(k)
+                    for k in ("subject", "session", "run", "task")
+                    if record.get(k) is not None
+                }
+
+                if bids_ds is not None:
+                    try:
+                        rel_from_dataset = Path(record["bidspath"]).relative_to(
+                            record["dataset"]
+                        )  # type: ignore[index]
+                        local_file = (self.data_dir / rel_from_dataset).as_posix()
+                        part_row = bids_ds.subject_participant_tsv(local_file)
+                        desc = merge_participants_fields(
+                            description=desc,
+                            participants_row=part_row
+                            if isinstance(part_row, dict)
+                            else None,
+                            description_fields=description_fields,
+                        )
+                    except Exception:
+                        pass
 
-        try:
-            if records is not None:
-                self.records = records
-                datasets = [
+                datasets.append(
                     EEGDashBaseDataset(
-                        record,
-                        self.cache_dir,
-                        self.s3_bucket,
-                        **base_dataset_kwargs,
-                    )
-                    for record in self.records
-                ]
-            elif offline_mode:  # only assume local data is complete if in offline mode
-                if self.data_dir.exists():
-                    # This path loads from a local directory and is not affected by DB query logic
-                    datasets = self.load_bids_daxtaset(
-                        dataset=self.query["dataset"],
-                        data_dir=self.data_dir,
-                        description_fields=description_fields,
-                        s3_bucket=s3_bucket,
-                        n_jobs=n_jobs,
+                        record=record,
+                        cache_dir=self.cache_dir,
+                        s3_bucket=self.s3_bucket,
+                        description=desc,
                         **base_dataset_kwargs,
                     )
-                else:
-                    raise ValueError(
-                        f"Offline mode is enabled, but local data_dir {self.data_dir} does not exist."
-                    )
-            elif self.query:
-                # This is the DB query path that we are improving
-                datasets = self._find_datasets(
-                    query=self.eeg_dash._build_query_from_kwargs(**self.query),
-                    description_fields=description_fields,
-                    base_dataset_kwargs=base_dataset_kwargs,
-                )
-                # We only need filesystem if we need to access S3
-                self.filesystem = S3FileSystem(
-                    anon=True, client_kwargs={"region_name": "us-east-2"}
-                )
-            else:
-                raise ValueError(
-                    "You must provide either 'records', a 'data_dir', or a query/keyword arguments for filtering."
                 )
-        finally:
-            if _owns_client and self.eeg_dash is not None:
-                self.eeg_dash.close()
+        elif self.query:
+            # This is the DB query path that we are improving
+            datasets = self._find_datasets(
+                query=build_query_from_kwargs(**self.query),
+                description_fields=description_fields,
+                base_dataset_kwargs=base_dataset_kwargs,
+            )
+            # We only need filesystem if we need to access S3
+            self.filesystem = S3FileSystem(
+                anon=True, client_kwargs={"region_name": "us-east-2"}
+            )
+        else:
+            raise ValueError(
+                "You must provide either 'records', a 'data_dir', or a query/keyword arguments for filtering."
+            )
 
         super().__init__(datasets)
 
-    def find_key_in_nested_dict(self, data: Any, target_key: str) -> Any:
-        """Helper to recursively search for a key in a nested dictionary structure; returns
-        the value associated with the first occurrence of the key, or None if not found.
+    def _find_local_bids_records(
+        self, dataset_root: Path, filters: dict[str, Any]
+    ) -> list[dict]:
+        """Discover local BIDS EEG files and build minimal records.
+
+        This helper enumerates EEG recordings under ``dataset_root`` via
+        ``mne_bids.find_matching_paths`` and applies entity filters to produce a
+        list of records suitable for ``EEGDashBaseDataset``. No network access
+        is performed and files are not read.
+
+        Parameters
+        ----------
+        dataset_root : Path
+            Local dataset directory. May be the plain dataset folder (e.g.,
+            ``ds005509``) or a suffixed cache variant (e.g.,
+            ``ds005509-bdf-mini``).
+        filters : dict of {str, Any}
+            Query filters. Must include ``'dataset'`` with the dataset id (without
+            local suffixes). May include BIDS entities ``'subject'``,
+            ``'session'``, ``'task'``, and ``'run'``. Each value can be a scalar
+            or a sequence of scalars.
+
+        Returns
+        -------
+        records : list of dict
+            One record per matched EEG file with at least:
+
+            - ``'data_name'``
+            - ``'dataset'`` (dataset id, without suffixes)
+            - ``'bidspath'`` (normalized to start with the dataset id)
+            - ``'subject'``, ``'session'``, ``'task'``, ``'run'`` (may be None)
+            - ``'bidsdependencies'`` (empty list)
+            - ``'modality'`` (``"eeg"``)
+            - ``'sampling_frequency'``, ``'nchans'``, ``'ntimes'`` (minimal
+              defaults for offline usage)
+
+        Notes
+        -----
+        - Matching uses ``datatypes=['eeg']`` and ``suffixes=['eeg']``.
+        - ``bidspath`` is constructed as
+          ``<dataset_id> / <relative_path_from_dataset_root>`` to ensure the
+          first path component is the dataset id (without local cache suffixes).
+        - Minimal defaults are set for ``sampling_frequency``, ``nchans``, and
+          ``ntimes`` to satisfy dataset length requirements offline.
+
         """
+        dataset_id = filters["dataset"]
+        arg_map = {
+            "subjects": "subject",
+            "sessions": "session",
+            "tasks": "task",
+            "runs": "run",
+        }
+        matching_args: dict[str, list[str]] = {}
+        for finder_key, entity_key in arg_map.items():
+            entity_val = filters.get(entity_key)
+            if entity_val is None:
+                continue
+            if isinstance(entity_val, (list, tuple, set)):
+                entity_vals = list(entity_val)
+                if not entity_vals:
+                    continue
+                matching_args[finder_key] = entity_vals
+            else:
+                matching_args[finder_key] = [entity_val]
+
+        matched_paths = find_matching_paths(
+            root=str(dataset_root),
+            datatypes=["eeg"],
+            suffixes=["eeg"],
+            ignore_json=True,
+            **matching_args,
+        )
+        records_out: list[dict] = []
+
+        for bids_path in matched_paths:
+            # Build bidspath as dataset_id / relative_path_from_dataset_root (POSIX)
+            rel_from_root = (
+                Path(bids_path.fpath)
+                .resolve()
+                .relative_to(Path(bids_path.root).resolve())
+            )
+            bidspath = f"{dataset_id}/{rel_from_root.as_posix()}"
+
+            rec = {
+                "data_name": f"{dataset_id}_{Path(bids_path.fpath).name}",
+                "dataset": dataset_id,
+                "bidspath": bidspath,
+                "subject": (bids_path.subject or None),
+                "session": (bids_path.session or None),
+                "task": (bids_path.task or None),
+                "run": (bids_path.run or None),
+                # minimal fields to satisfy BaseDataset from eegdash
+                "bidsdependencies": [],  # not needed just to run
+                "modality": "eeg",
+                # minimal numeric defaults for offline length calculation
+                "sampling_frequency": None,
+                "nchans": None,
+                "ntimes": None,
+            }
+            records_out.append(rec)
+
+        return records_out
+
+    def _find_key_in_nested_dict(self, data: Any, target_key: str) -> Any:
+        """Recursively search for target_key in nested dicts/lists with normalized matching.
+
+        This makes lookups tolerant to naming differences like "p-factor" vs "p_factor".
+        Returns the first match or None.
+        """
+        norm_target = normalize_key(target_key)
         if isinstance(data, dict):
-            if target_key in data:
-                return data[target_key]
-            for value in data.values():
-                result = self.find_key_in_nested_dict(value, target_key)
-                if result is not None:
-                    return result
+            for k, v in data.items():
+                if normalize_key(k) == norm_target:
+                    return v
+                res = self._find_key_in_nested_dict(v, target_key)
+                if res is not None:
+                    return res
+        elif isinstance(data, list):
+            for item in data:
+                res = self._find_key_in_nested_dict(item, target_key)
+                if res is not None:
+                    return res
         return None
 
     def _find_datasets(
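normalize_key comes from .bids_eeg_metadata and its exact rules are not shown in this diff; assuming it canonicalizes case and separators, the tolerant matching above behaves like this sketch:

    import re

    def normalize_key(key: str) -> str:
        # assumed behavior: case- and separator-insensitive
        return re.sub(r"[\s_\-]+", "", key.lower())

    assert normalize_key("p-factor") == normalize_key("p_factor") == normalize_key("P Factor")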
@@ -892,15 +1018,23 @@ class EEGDashDataset(BaseConcatDataset):
 
         """
         datasets: list[EEGDashBaseDataset] = []
-
-        self.records = self.eeg_dash.find(query)
+        self.records = self.eeg_dash_instance.find(query)
 
         for record in self.records:
-            description = {}
+            description: dict[str, Any] = {}
+            # Requested fields first (normalized matching)
             for field in description_fields:
-                value = self.find_key_in_nested_dict(record, field)
+                value = self._find_key_in_nested_dict(record, field)
                 if value is not None:
                     description[field] = value
+            # Merge all participants.tsv columns generically
+            part = self._find_key_in_nested_dict(record, "participant_tsv")
+            if isinstance(part, dict):
+                description = merge_participants_fields(
+                    description=description,
+                    participants_row=part,
+                    description_fields=description_fields,
+                )
             datasets.append(
                 EEGDashBaseDataset(
                     record,
@@ -911,69 +1045,3 @@ class EEGDashDataset(BaseConcatDataset):
                 )
             )
         return datasets
-
-    def load_bids_dataset(
-        self,
-        dataset: str,
-        data_dir: str | Path,
-        description_fields: list[str],
-        s3_bucket: str | None = None,
-        n_jobs: int = -1,
-        **kwargs,
-    ):
-        """Helper method to load a single local BIDS dataset and return it as a list of
-        EEGDashBaseDatasets (one for each recording in the dataset).
-
-        Parameters
-        ----------
-        dataset : str
-            A name for the dataset to be loaded (e.g., "ds002718").
-        data_dir : str
-            The path to the local BIDS dataset directory.
-        description_fields : list[str]
-            A list of fields to be extracted from the dataset records
-            and included in the returned dataset description(s).
-        s3_bucket : str | None
-            The S3 bucket to upload the dataset files to (if any).
-        n_jobs : int
-            The number of jobs to run in parallel (default is -1, meaning using all processors).
-
-        """
-        bids_dataset = EEGBIDSDataset(
-            data_dir=data_dir,
-            dataset=dataset,
-        )
-        datasets = Parallel(n_jobs=n_jobs, prefer="threads", verbose=1)(
-            delayed(self.get_base_dataset_from_bids_file)(
-                bids_dataset=bids_dataset,
-                bids_file=bids_file,
-                s3_bucket=s3_bucket,
-                description_fields=description_fields,
-                **kwargs,
-            )
-            for bids_file in bids_dataset.get_files()
-        )
-        return datasets
-
-    def get_base_dataset_from_bids_file(
-        self,
-        bids_dataset: "EEGBIDSDataset",
-        bids_file: str,
-        s3_bucket: str | None,
-        description_fields: list[str],
-        **kwargs,
-    ) -> "EEGDashBaseDataset":
-        """Instantiate a single EEGDashBaseDataset given a local BIDS file (metadata only)."""
-        record = self.eeg_dash.load_eeg_attrs_from_bids_file(bids_dataset, bids_file)
-        description = {}
-        for field in description_fields:
-            value = self.find_key_in_nested_dict(record, field)
-            if value is not None:
-                description[field] = value
-        return EEGDashBaseDataset(
-            record,
-            self.cache_dir,
-            s3_bucket,
-            description=description,
-            **kwargs,
-        )
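The removed helpers are superseded by the module-level load_eeg_attrs_from_bids_file plus the records= constructor path; a hedged sketch of the equivalent flow (import locations follow this diff's relative imports, local paths are placeholders):

    from eegdash.api import EEGDashDataset
    from eegdash.bids_eeg_metadata import load_eeg_attrs_from_bids_file
    from eegdash.data_utils import EEGBIDSDataset

    bids_ds = EEGBIDSDataset(data_dir="/data/ds002718", dataset="ds002718")
    records = [load_eeg_attrs_from_bids_file(bids_ds, f) for f in bids_ds.get_files()]

    # With records supplied, no MongoDB query is performed.
    ds = EEGDashDataset(cache_dir="/data/eegdash_cache", records=records)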