eegdash 0.3.6.dev182011805__py3-none-any.whl → 0.3.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

eegdash/api.py CHANGED
@@ -3,66 +3,67 @@ import os
  import tempfile
  from pathlib import Path
  from typing import Any, Mapping
+ from urllib.parse import urlsplit

  import mne
  import numpy as np
  import xarray as xr
  from dotenv import load_dotenv
  from joblib import Parallel, delayed
- from mne_bids import get_bids_path_from_fname, read_raw_bids
+ from mne.utils import warn
+ from mne_bids import find_matching_paths, get_bids_path_from_fname, read_raw_bids
  from pymongo import InsertOne, UpdateOne
  from s3fs import S3FileSystem

  from braindecode.datasets import BaseConcatDataset

- from .data_config import config as data_config
- from .data_utils import EEGBIDSDataset, EEGDashBaseDataset
+ from .bids_eeg_metadata import (
+     build_query_from_kwargs,
+     load_eeg_attrs_from_bids_file,
+     merge_participants_fields,
+     normalize_key,
+ )
+ from .const import (
+     ALLOWED_QUERY_FIELDS,
+     RELEASE_TO_OPENNEURO_DATASET_MAP,
+ )
+ from .const import config as data_config
+ from .data_utils import (
+     EEGBIDSDataset,
+     EEGDashBaseDataset,
+ )
  from .mongodb import MongoConnectionManager
+ from .paths import get_default_cache_dir

  logger = logging.getLogger("eegdash")


  class EEGDash:
-     """A high-level interface to the EEGDash database.
+     """High-level interface to the EEGDash metadata database.

-     This class is primarily used to interact with the metadata records stored in the
-     EEGDash database (or a private instance of it), allowing users to find, add, and
-     update EEG data records.
-
-     While this class provides basic support for loading EEG data, please see
-     the EEGDashDataset class for a more complete way to retrieve and work with full
-     datasets.
+     Provides methods to query, insert, and update metadata records stored in the
+     EEGDash MongoDB database (public or private). Also includes utilities to load
+     EEG data from S3 for matched records.

+     For working with collections of
+     recordings as PyTorch datasets, prefer :class:`EEGDashDataset`.
      """

-     _ALLOWED_QUERY_FIELDS = {
-         "data_name",
-         "dataset",
-         "subject",
-         "task",
-         "session",
-         "run",
-         "modality",
-         "sampling_frequency",
-         "nchans",
-         "ntimes",
-     }
-
      def __init__(self, *, is_public: bool = True, is_staging: bool = False) -> None:
-         """Create new instance of the EEGDash Database client.
+         """Create a new EEGDash client.

          Parameters
          ----------
-         is_public: bool
-             Whether to connect to the public MongoDB database; if False, connect to a
-             private database instance as per the DB_CONNECTION_STRING env variable
-             (or .env file entry).
-         is_staging: bool
-             If True, use staging MongoDB database ("eegdashstaging"); otherwise use the
-             production database ("eegdash").
-
-         Example
-         -------
+         is_public : bool, default True
+             Connect to the public MongoDB database. If ``False``, connect to a
+             private database instance using the ``DB_CONNECTION_STRING`` environment
+             variable (or value from a ``.env`` file).
+         is_staging : bool, default False
+             If ``True``, use the staging database (``eegdashstaging``); otherwise
+             use the production database (``eegdash``).
+
+         Examples
+         --------
          >>> eegdash = EEGDash()

          """
@@ -103,23 +104,25 @@ class EEGDash:

          Parameters
          ----------
-         query: dict, optional
-             A complete MongoDB query dictionary. This is a positional-only argument.
-         **kwargs:
-             Keyword arguments representing field-value pairs for the query.
-             Values can be single items (str, int) or lists of items for multi-search.
+         query : dict, optional
+             Complete MongoDB query dictionary. This is a positional-only
+             argument.
+         **kwargs
+             User-friendly field filters that are converted to a MongoDB query.
+             Values can be scalars (e.g., ``"sub-01"``) or sequences (translated
+             to ``$in`` queries).

          Returns
          -------
-         list:
-             A list of DB records (string-keyed dictionaries) that match the query.
+         list of dict
+             DB records that match the query.

          """
          final_query: dict[str, Any] | None = None

          # Accept explicit empty dict {} to mean "match all"
          raw_query = query if isinstance(query, dict) else None
-         kwargs_query = self._build_query_from_kwargs(**kwargs) if kwargs else None
+         kwargs_query = build_query_from_kwargs(**kwargs) if kwargs else None

          # Determine presence, treating {} as a valid raw query
          has_raw = isinstance(raw_query, dict)
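
Note: per the docstring above, scalar keyword filters become exact matches while sequence values become ``$in`` constraints. A minimal sketch of the call pattern (dataset, subject, and task values are hypothetical):

    from eegdash.api import EEGDash

    client = EEGDash()
    records = client.find(dataset="ds002718", subject=["012", "013"], task="RestingState")
    # The keyword filters above are translated to, approximately:
    # {"dataset": "ds002718", "subject": {"$in": ["012", "013"]}, "task": "RestingState"}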
@@ -236,59 +239,12 @@ class EEGDash:
          return record

      def _build_query_from_kwargs(self, **kwargs) -> dict[str, Any]:
-         """Build and validate a MongoDB query from user-friendly keyword arguments.
+         """Internal helper to build a validated MongoDB query from keyword args.

-         Improvements:
-         - Reject None values and empty/whitespace-only strings
-         - For list/tuple/set values: strip strings, drop None/empties, deduplicate, and use `$in`
-         - Preserve scalars as exact matches
+         This delegates to the module-level builder used across the package and
+         is exposed here for testing and convenience.
          """
-         # 1. Validate that all provided keys are allowed for querying
-         unknown_fields = set(kwargs.keys()) - self._ALLOWED_QUERY_FIELDS
-         if unknown_fields:
-             raise ValueError(
-                 f"Unsupported query field(s): {', '.join(sorted(unknown_fields))}. "
-                 f"Allowed fields are: {', '.join(sorted(self._ALLOWED_QUERY_FIELDS))}"
-             )
-
-         # 2. Construct the query dictionary
-         query = {}
-         for key, value in kwargs.items():
-             # None is not a valid constraint
-             if value is None:
-                 raise ValueError(
-                     f"Received None for query parameter '{key}'. Provide a concrete value."
-                 )
-
-             # Handle list-like values as multi-constraints
-             if isinstance(value, (list, tuple, set)):
-                 cleaned: list[Any] = []
-                 for item in value:
-                     if item is None:
-                         continue
-                     if isinstance(item, str):
-                         item = item.strip()
-                         if not item:
-                             continue
-                     cleaned.append(item)
-                 # Deduplicate while preserving order
-                 cleaned = list(dict.fromkeys(cleaned))
-                 if not cleaned:
-                     raise ValueError(
-                         f"Received an empty list for query parameter '{key}'. This is not supported."
-                     )
-                 query[key] = {"$in": cleaned}
-             else:
-                 # Scalars: trim strings and validate
-                 if isinstance(value, str):
-                     value = value.strip()
-                     if not value:
-                         raise ValueError(
-                             f"Received an empty string for query parameter '{key}'."
-                         )
-                 query[key] = value
-
-         return query
+         return build_query_from_kwargs(**kwargs)

      # --- Query merging and conflict detection helpers ---
      def _extract_simple_constraint(self, query: dict[str, Any], key: str):
@@ -321,8 +277,8 @@ class EEGDash:
              return

          # Only consider fields we generally allow; skip meta operators like $and
-         raw_keys = set(raw_query.keys()) & self._ALLOWED_QUERY_FIELDS
-         kw_keys = set(kwargs_query.keys()) & self._ALLOWED_QUERY_FIELDS
+         raw_keys = set(raw_query.keys()) & ALLOWED_QUERY_FIELDS
+         kw_keys = set(kwargs_query.keys()) & ALLOWED_QUERY_FIELDS
          dup_keys = raw_keys & kw_keys
          for key in dup_keys:
              rc = self._extract_simple_constraint(raw_query, key)
@@ -357,44 +313,95 @@ class EEGDash:
            )

      def load_eeg_data_from_s3(self, s3path: str) -> xr.DataArray:
-         """Load an EEGLAB .set file from an AWS S3 URI and return it as an xarray DataArray.
+         """Load EEG data from an S3 URI into an ``xarray.DataArray``.
+
+         Preserves the original filename, downloads sidecar files when applicable
+         (e.g., ``.fdt`` for EEGLAB, ``.vmrk``/``.eeg`` for BrainVision), and uses
+         MNE's direct readers.

          Parameters
          ----------
          s3path : str
-             An S3 URI (should start with "s3://") for the file in question.
+             An S3 URI (should start with "s3://").

          Returns
          -------
          xr.DataArray
-             A DataArray containing the EEG data, with dimensions "channel" and "time".
+             EEG data with dimensions ``("channel", "time")``.

-         Example
-         -------
-         >>> eegdash = EEGDash()
-         >>> mypath = "s3://openneuro.org/path/to/your/eeg_data.set"
-         >>> mydata = eegdash.load_eeg_data_from_s3(mypath)
+         Raises
+         ------
+         ValueError
+             If the file extension is unsupported.

          """
-         with tempfile.NamedTemporaryFile(delete=False, suffix=".set") as tmp:
-             with self.filesystem.open(s3path) as s3_file:
-                 tmp.write(s3_file.read())
-             tmp_path = tmp.name
-         eeg_data = self.load_eeg_data_from_bids_file(tmp_path)
-         os.unlink(tmp_path)
-         return eeg_data
+         # choose a temp dir so sidecars can be colocated
+         with tempfile.TemporaryDirectory() as tmpdir:
+             # Derive local filenames from the S3 key to keep base name consistent
+             s3_key = urlsplit(s3path).path # e.g., "/dsXXXX/sub-.../..._eeg.set"
+             basename = Path(s3_key).name
+             ext = Path(basename).suffix.lower()
+             local_main = Path(tmpdir) / basename
+
+             # Download main file
+             with (
+                 self.filesystem.open(s3path, mode="rb") as fsrc,
+                 open(local_main, "wb") as fdst,
+             ):
+                 fdst.write(fsrc.read())
+
+             # Determine and fetch any required sidecars
+             sidecars: list[str] = []
+             if ext == ".set": # EEGLAB
+                 sidecars = [".fdt"]
+             elif ext == ".vhdr": # BrainVision
+                 sidecars = [".vmrk", ".eeg", ".dat", ".raw"]
+
+             for sc_ext in sidecars:
+                 sc_key = s3_key[: -len(ext)] + sc_ext
+                 sc_uri = f"s3://{urlsplit(s3path).netloc}{sc_key}"
+                 try:
+                     # If sidecar exists, download next to the main file
+                     info = self.filesystem.info(sc_uri)
+                     if info:
+                         sc_local = Path(tmpdir) / Path(sc_key).name
+                         with (
+                             self.filesystem.open(sc_uri, mode="rb") as fsrc,
+                             open(sc_local, "wb") as fdst,
+                         ):
+                             fdst.write(fsrc.read())
+                 except Exception:
+                     # Sidecar not present; skip silently
+                     pass
+
+             # Read using appropriate MNE reader
+             raw = mne.io.read_raw(str(local_main), preload=True, verbose=False)
+
+             data = raw.get_data()
+             fs = raw.info["sfreq"]
+             max_time = data.shape[1] / fs
+             time_steps = np.linspace(0, max_time, data.shape[1]).squeeze()
+             channel_names = raw.ch_names
+
+             return xr.DataArray(
+                 data=data,
+                 dims=["channel", "time"],
+                 coords={"time": time_steps, "channel": channel_names},
+             )

      def load_eeg_data_from_bids_file(self, bids_file: str) -> xr.DataArray:
-         """Load EEG data from a local file and return it as a xarray DataArray.
+         """Load EEG data from a local BIDS-formatted file.

          Parameters
          ----------
          bids_file : str
-             Path to the BIDS-compliant file on the local filesystem.
+             Path to a BIDS-compliant EEG file (e.g., ``*_eeg.edf``, ``*_eeg.bdf``,
+             ``*_eeg.vhdr``, ``*_eeg.set``).

-         Notes
-         -----
-         Currently, only non-epoched .set files are supported.
+         Returns
+         -------
+         xr.DataArray
+             EEG data with dimensions ``("channel", "time")``.

          """
          bids_path = get_bids_path_from_fname(bids_file, verbose=False)
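
Note: the rewritten S3 loader above keeps the original base name so sidecar files resolve next to the main file. A hedged usage sketch; the OpenNeuro key below is hypothetical, and any supported extension (.set, .vhdr, .edf, .bdf) is handled the same way:

    from eegdash.api import EEGDash

    client = EEGDash()
    arr = client.load_eeg_data_from_s3(
        "s3://openneuro.org/ds002718/sub-012/eeg/sub-012_task-RestingState_eeg.set"
    )
    print(arr.dims)                   # ("channel", "time")
    band_mean = arr.mean(dim="time")  # ordinary xarray reductions apply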
@@ -414,140 +421,25 @@ class EEGDash:
          )
          return eeg_xarray

-     def get_raw_extensions(
-         self, bids_file: str, bids_dataset: EEGBIDSDataset
-     ) -> list[str]:
-         """Helper to find paths to additional "sidecar" files that may be associated
-         with a given main data file in a BIDS dataset; paths are returned as relative to
-         the parent dataset path.
-
-         For example, if the input file is a .set file, this will return the relative path
-         to a corresponding .fdt file (if any).
-         """
-         bids_file = Path(bids_file)
-         extensions = {
-             ".set": [".set", ".fdt"], # eeglab
-             ".edf": [".edf"], # european
-             ".vhdr": [".eeg", ".vhdr", ".vmrk", ".dat", ".raw"], # brainvision
-             ".bdf": [".bdf"], # biosemi
-         }
-         return [
-             str(bids_dataset.get_relative_bidspath(bids_file.with_suffix(suffix)))
-             for suffix in extensions[bids_file.suffix]
-             if bids_file.with_suffix(suffix).exists()
-         ]
-
-     def load_eeg_attrs_from_bids_file(
-         self, bids_dataset: EEGBIDSDataset, bids_file: str
-     ) -> dict[str, Any]:
-         """Build the metadata record for a given BIDS file (single recording) in a BIDS dataset.
-
-         Attributes are at least the ones defined in data_config attributes (set to None if missing),
-         but are typically a superset, and include, among others, the paths to relevant
-         meta-data files needed to load and interpret the file in question.
-
-         Parameters
-         ----------
-         bids_dataset : EEGBIDSDataset
-             The BIDS dataset object containing the file.
-         bids_file : str
-             The path to the BIDS file within the dataset.
-
-         Returns
-         -------
-         dict:
-             A dictionary representing the metadata record for the given file. This is the
-             same format as the records stored in the database.
-
-         """
-         if bids_file not in bids_dataset.files:
-             raise ValueError(f"{bids_file} not in {bids_dataset.dataset}")
-
-         # Initialize attrs with None values for all expected fields
-         attrs = {field: None for field in self.config["attributes"].keys()}
-
-         file = Path(bids_file).name
-         dsnumber = bids_dataset.dataset
-         # extract openneuro path by finding the first occurrence of the dataset name in the filename and remove the path before that
-         openneuro_path = dsnumber + bids_file.split(dsnumber)[1]
-
-         # Update with actual values where available
-         try:
-             participants_tsv = bids_dataset.subject_participant_tsv(bids_file)
-         except Exception as e:
-             logger.error("Error getting participants_tsv: %s", str(e))
-             participants_tsv = None
-
-         try:
-             eeg_json = bids_dataset.eeg_json(bids_file)
-         except Exception as e:
-             logger.error("Error getting eeg_json: %s", str(e))
-             eeg_json = None
-
-         bids_dependencies_files = self.config["bids_dependencies_files"]
-         bidsdependencies = []
-         for extension in bids_dependencies_files:
-             try:
-                 dep_path = bids_dataset.get_bids_metadata_files(bids_file, extension)
-                 dep_path = [
-                     str(bids_dataset.get_relative_bidspath(dep)) for dep in dep_path
-                 ]
-                 bidsdependencies.extend(dep_path)
-             except Exception:
-                 pass
-
-         bidsdependencies.extend(self.get_raw_extensions(bids_file, bids_dataset))
-
-         # Define field extraction functions with error handling
-         field_extractors = {
-             "data_name": lambda: f"{bids_dataset.dataset}_{file}",
-             "dataset": lambda: bids_dataset.dataset,
-             "bidspath": lambda: openneuro_path,
-             "subject": lambda: bids_dataset.get_bids_file_attribute(
-                 "subject", bids_file
-             ),
-             "task": lambda: bids_dataset.get_bids_file_attribute("task", bids_file),
-             "session": lambda: bids_dataset.get_bids_file_attribute(
-                 "session", bids_file
-             ),
-             "run": lambda: bids_dataset.get_bids_file_attribute("run", bids_file),
-             "modality": lambda: bids_dataset.get_bids_file_attribute(
-                 "modality", bids_file
-             ),
-             "sampling_frequency": lambda: bids_dataset.get_bids_file_attribute(
-                 "sfreq", bids_file
-             ),
-             "nchans": lambda: bids_dataset.get_bids_file_attribute("nchans", bids_file),
-             "ntimes": lambda: bids_dataset.get_bids_file_attribute("ntimes", bids_file),
-             "participant_tsv": lambda: participants_tsv,
-             "eeg_json": lambda: eeg_json,
-             "bidsdependencies": lambda: bidsdependencies,
-         }
-
-         # Dynamically populate attrs with error handling
-         for field, extractor in field_extractors.items():
-             try:
-                 attrs[field] = extractor()
-             except Exception as e:
-                 logger.error("Error extracting %s : %s", field, str(e))
-                 attrs[field] = None
-
-         return attrs
-
      def add_bids_dataset(
          self, dataset: str, data_dir: str, overwrite: bool = True
      ) -> None:
-         """Traverse the BIDS dataset at data_dir and add its records to the MongoDB database,
-         under the given dataset name.
+         """Scan a local BIDS dataset and upsert records into MongoDB.

          Parameters
          ----------
-         dataset : str)
-             The name of the dataset to be added (e.g., "ds002718").
+         dataset : str
+             Dataset identifier (e.g., ``"ds002718"``).
          data_dir : str
-             The path to the BIDS dataset directory.
-         overwrite : bool
-             Whether to overwrite/update existing records in the database.
+             Path to the local BIDS dataset directory.
+         overwrite : bool, default True
+             If ``True``, update existing records when encountered; otherwise,
+             skip records that already exist.
+
+         Raises
+         ------
+         ValueError
+             If called on a public client ``(is_public=True)``.

          """
          if self.is_public:
@@ -562,7 +454,7 @@ class EEGDash:
                  dataset=dataset,
              )
          except Exception as e:
-             logger.error("Error creating bids dataset %s: $s", dataset, str(e))
+             logger.error("Error creating bids dataset %s: %s", dataset, str(e))
              raise e
          requests = []
          for bids_file in bids_dataset.get_files():
@@ -571,15 +463,13 @@ class EEGDash:

                  if self.exist({"data_name": data_id}):
                      if overwrite:
-                         eeg_attrs = self.load_eeg_attrs_from_bids_file(
+                         eeg_attrs = load_eeg_attrs_from_bids_file(
                              bids_dataset, bids_file
                          )
-                         requests.append(self.update_request(eeg_attrs))
+                         requests.append(self._update_request(eeg_attrs))
                  else:
-                     eeg_attrs = self.load_eeg_attrs_from_bids_file(
-                         bids_dataset, bids_file
-                     )
-                     requests.append(self.add_request(eeg_attrs))
+                     eeg_attrs = load_eeg_attrs_from_bids_file(bids_dataset, bids_file)
+                     requests.append(self._add_request(eeg_attrs))
              except Exception as e:
                  logger.error("Error adding record %s", bids_file)
                  logger.error(str(e))
@@ -595,22 +485,22 @@ class EEGDash:
              logger.info("Errors: %s ", result.bulk_api_result.get("writeErrors", []))

      def get(self, query: dict[str, Any]) -> list[xr.DataArray]:
-         """Retrieve a list of EEG data arrays that match the given query. See also
-         the `find()` method for details on the query format.
+         """Download and return EEG data arrays for records matching a query.

          Parameters
          ----------
          query : dict
-             A dictionary that specifies the query to be executed; this is a reference
-             document that is used to match records in the MongoDB collection.
+             MongoDB query used to select records.

          Returns
          -------
-         A list of xarray DataArray objects containing the EEG data for each matching record.
+         list of xr.DataArray
+             EEG data for each matching record, with dimensions ``("channel", "time")``.

          Notes
          -----
-         Retrieval is done in parallel, and the downloaded data are not cached locally.
+         Retrieval runs in parallel. Downloaded files are read and discarded
+         (no on-disk caching here).

          """
          sessions = self.find(query)
@@ -620,12 +510,40 @@ class EEGDash:
          results = Parallel(
              n_jobs=-1 if len(sessions) > 1 else 1, prefer="threads", verbose=1
          )(
-             delayed(self.load_eeg_data_from_s3)(self.get_s3path(session))
+             delayed(self.load_eeg_data_from_s3)(self._get_s3path(session))
              for session in sessions
          )
          return results

-     def add_request(self, record: dict):
+     def _get_s3path(self, record: Mapping[str, Any] | str) -> str:
+         """Build an S3 URI from a DB record or a relative path.
+
+         Parameters
+         ----------
+         record : dict or str
+             Either a DB record containing a ``'bidspath'`` key, or a relative
+             path string under the OpenNeuro bucket.
+
+         Returns
+         -------
+         str
+             Fully qualified S3 URI.
+
+         Raises
+         ------
+         ValueError
+             If a mapping is provided but ``'bidspath'`` is missing.
+
+         """
+         if isinstance(record, str):
+             rel = record
+         else:
+             rel = record.get("bidspath")
+             if not rel:
+                 raise ValueError("Record missing 'bidspath' for S3 path resolution")
+         return f"s3://openneuro.org/{rel}"
+
+     def _add_request(self, record: dict):
          """Internal helper method to create a MongoDB insertion request for a record."""
          return InsertOne(record)

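Note: the `_get_s3path` helper added above accepts either a record mapping or a bare relative path, and both forms should resolve to the same URI. A small sketch with a hypothetical path (the method is private, so this is illustration only, not public API):

    from eegdash.api import EEGDash

    client = EEGDash()
    rel = "ds002718/sub-012/eeg/sub-012_task-RestingState_eeg.set"
    assert client._get_s3path(rel) == client._get_s3path({"bidspath": rel})
    # -> "s3://openneuro.org/ds002718/sub-012/eeg/sub-012_task-RestingState_eeg.set"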
@@ -639,12 +557,19 @@ class EEGDash:
          except:
              logger.error("Error adding record: %s ", record["data_name"])

-     def update_request(self, record: dict):
+     def _update_request(self, record: dict):
          """Internal helper method to create a MongoDB update request for a record."""
          return UpdateOne({"data_name": record["data_name"]}, {"$set": record})

      def update(self, record: dict):
-         """Update a single record in the MongoDB collection."""
+         """Update a single record in the MongoDB collection.
+
+         Parameters
+         ----------
+         record : dict
+             Record content to set at the matching ``data_name``.
+
+         """
          try:
              self.__collection.update_one(
                  {"data_name": record["data_name"]}, {"$set": record}
@@ -652,15 +577,33 @@ class EEGDash:
          except: # silent failure
              logger.error("Error updating record: %s", record["data_name"])

+     def exists(self, query: dict[str, Any]) -> bool:
+         """Alias for :meth:`exist` provided for API clarity."""
+         return self.exist(query)
+
      def remove_field(self, record, field):
-         """Remove a specific field from a record in the MongoDB collection."""
+         """Remove a specific field from a record in the MongoDB collection.
+
+         Parameters
+         ----------
+         record : dict
+             Record identifying object with ``data_name``.
+         field : str
+             Field name to remove.
+
+         """
          self.__collection.update_one(
              {"data_name": record["data_name"]}, {"$unset": {field: 1}}
          )

      def remove_field_from_db(self, field):
-         """Removed all occurrences of a specific field from all records in the MongoDB
-         collection. WARNING: this operation is destructive and should be used with caution.
+         """Remove a field from all records (destructive).
+
+         Parameters
+         ----------
+         field : str
+             Field name to remove from every document.
+
          """
          self.__collection.update_many({}, {"$unset": {field: 1}})

@@ -670,11 +613,13 @@ class EEGDash:
          return self.__collection

      def close(self):
-         """Close the MongoDB client connection.
+         """Backward-compatibility no-op; connections are managed globally.
+
+         Notes
+         -----
+         Connections are managed by :class:`MongoConnectionManager`. Use
+         :meth:`close_all_connections` to explicitly close all clients.

-         Note: Since MongoDB clients are now managed by a singleton,
-         this method no longer closes connections. Use close_all_connections()
-         class method to close all connections if needed.
          """
          # Individual instances no longer close the shared client
          pass
@@ -685,7 +630,7 @@ class EEGDash:
          MongoConnectionManager.close_all()

      def __del__(self):
-         """Ensure connection is closed when object is deleted."""
+         """Destructor; no explicit action needed due to global connection manager."""
          # No longer needed since we're using singleton pattern
          pass

@@ -693,9 +638,8 @@ class EEGDash:
  class EEGDashDataset(BaseConcatDataset):
      def __init__(
          self,
-         query: dict | None = None,
-         cache_dir: str = "~/eegdash_cache",
-         dataset: str | list[str] | None = None,
+         cache_dir: str | Path,
+         query: dict[str, Any] = None,
          description_fields: list[str] = [
              "subject",
              "session",
@@ -706,16 +650,16 @@ class EEGDashDataset(BaseConcatDataset):
              "sex",
          ],
          s3_bucket: str | None = None,
-         data_dir: str | None = None,
-         eeg_dash_instance=None,
          records: list[dict] | None = None,
+         download: bool = True,
+         n_jobs: int = -1,
+         eeg_dash_instance: EEGDash | None = None,
          **kwargs,
      ):
          """Create a new EEGDashDataset from a given query or local BIDS dataset directory
          and dataset name. An EEGDashDataset is pooled collection of EEGDashBaseDataset
          instances (individual recordings) and is a subclass of braindecode's BaseConcatDataset.

-
          Querying Examples:
          ------------------
          # Find by single subject
@@ -731,140 +675,314 @@ class EEGDashDataset(BaseConcatDataset):

          Parameters
          ----------
+         cache_dir : str | Path
+             Directory where data are cached locally. If not specified, a default
+             cache directory under the user cache is used.
          query : dict | None
-             A raw MongoDB query dictionary. If provided, keyword arguments for filtering are ignored.
-         **kwargs : dict
-             Keyword arguments for filtering (e.g., `subject="X"`, `task=["T1", "T2"]`) and/or
-             arguments to be passed to the EEGDashBaseDataset constructor (e.g., `subject=...`).
-         cache_dir : str
-             A directory where the dataset will be cached locally.
-         data_dir : str | None
-             Optionally a string specifying a local BIDS dataset directory from which to load the EEG data files. Exactly one
-             of query or data_dir must be provided.
-         dataset : str | None
-             If data_dir is given, a name for the dataset to be loaded.
+             Raw MongoDB query to filter records. If provided, it is merged with
+             keyword filtering arguments (see ``**kwargs``) using logical AND.
+             You must provide at least a ``dataset`` (either in ``query`` or
+             as a keyword argument). Only fields in ``ALLOWED_QUERY_FIELDS`` are
+             considered for filtering.
          description_fields : list[str]
-             A list of fields to be extracted from the dataset records
-             and included in the returned data description(s). Examples are typical
-             subject metadata fields such as "subject", "session", "run", "task", etc.;
-             see also data_config.description_fields for the default set of fields.
+             Fields to extract from each record and include in dataset descriptions
+             (e.g., "subject", "session", "run", "task").
          s3_bucket : str | None
-             An optional S3 bucket URI (e.g., "s3://mybucket") to use instead of the
-             default OpenNeuro bucket for loading data files
+             Optional S3 bucket URI (e.g., "s3://mybucket") to use instead of the
+             default OpenNeuro bucket when downloading data files.
          records : list[dict] | None
-             Optional list of pre-fetched metadata records. If provided, the dataset is
-             constructed directly from these records without querying MongoDB.
-         kwargs : dict
-             Additional keyword arguments to be passed to the EEGDashBaseDataset
-             constructor.
+             Pre-fetched metadata records. If provided, the dataset is constructed
+             directly from these records and no MongoDB query is performed.
+         download : bool, default True
+             If False, load from local BIDS files only. Local data are expected
+             under ``cache_dir / dataset``; no DB or S3 access is attempted.
+         n_jobs : int
+             Number of parallel jobs to use where applicable (-1 uses all cores).
+         eeg_dash_instance : EEGDash | None
+             Optional existing EEGDash client to reuse for DB queries. If None,
+             a new client is created on demand, not used in the case of no download.
+         **kwargs : dict
+             Additional keyword arguments serving two purposes:
+             - Filtering: any keys present in ``ALLOWED_QUERY_FIELDS`` are treated
+               as query filters (e.g., ``dataset``, ``subject``, ``task``, ...).
+             - Dataset options: remaining keys are forwarded to the
+               ``EEGDashBaseDataset`` constructor.

          """
-         self.cache_dir = cache_dir
+         # Parameters that don't need validation
+         _suppress_comp_warning: bool = kwargs.pop("_suppress_comp_warning", False)
          self.s3_bucket = s3_bucket
-         self.eeg_dash = eeg_dash_instance
-         _owns_client = False
-         if self.eeg_dash is None and records is None:
-             self.eeg_dash = EEGDash()
-             _owns_client = True
+         self.records = records
+         self.download = download
+         self.n_jobs = n_jobs
+         self.eeg_dash_instance = eeg_dash_instance or EEGDash()

-         # Separate query kwargs from other kwargs passed to the BaseDataset constructor
-         query_kwargs = {
-             k: v for k, v in kwargs.items() if k in EEGDash._ALLOWED_QUERY_FIELDS
-         }
-         base_dataset_kwargs = {k: v for k, v in kwargs.items() if k not in query_kwargs}
+         # Resolve a unified cache directory across code/tests/CI
+         self.cache_dir = Path(cache_dir or get_default_cache_dir())

-         # If user provided a dataset name via the dedicated parameter (and we're not
-         # loading from a local directory), treat it as a query filter. Accept str or list.
-         if data_dir is None and dataset is not None:
-             # Allow callers to pass a single dataset id (str) or a list of them.
-             # If list is provided, let _build_query_from_kwargs turn it into $in later.
-             query_kwargs.setdefault("dataset", dataset)
+         if not self.cache_dir.exists():
+             warn(f"Cache directory does not exist, creating it: {self.cache_dir}")
+             self.cache_dir.mkdir(exist_ok=True, parents=True)

-         # Allow mixing raw DB query with additional keyword filters. Both will be
-         # merged by EEGDash.find() (logical AND), so we do not raise here.
+         # Separate query kwargs from other kwargs passed to the BaseDataset constructor
+         self.query = query or {}
+         self.query.update(
+             {k: v for k, v in kwargs.items() if k in ALLOWED_QUERY_FIELDS}
+         )
+         base_dataset_kwargs = {k: v for k, v in kwargs.items() if k not in self.query}
+         if "dataset" not in self.query:
+             # If explicit records are provided, infer dataset from records
+             if isinstance(records, list) and records and isinstance(records[0], dict):
+                 inferred = records[0].get("dataset")
+                 if inferred:
+                     self.query["dataset"] = inferred
+                 else:
+                     raise ValueError("You must provide a 'dataset' argument")
+             else:
+                 raise ValueError("You must provide a 'dataset' argument")
+
+         # Decide on a dataset subfolder name for cache isolation. If using
+         # challenge/preprocessed buckets (e.g., BDF, mini subsets), append
+         # informative suffixes to avoid overlapping with the original dataset.
+         dataset_folder = self.query["dataset"]
+         if self.s3_bucket:
+             suffixes: list[str] = []
+             bucket_lower = str(self.s3_bucket).lower()
+             if "bdf" in bucket_lower:
+                 suffixes.append("bdf")
+             if "mini" in bucket_lower:
+                 suffixes.append("mini")
+             if suffixes:
+                 dataset_folder = f"{dataset_folder}-{'-'.join(suffixes)}"
+
+         self.data_dir = self.cache_dir / dataset_folder
+
+         if (
+             not _suppress_comp_warning
+             and self.query["dataset"] in RELEASE_TO_OPENNEURO_DATASET_MAP.values()
+         ):
+             warn(
+                 "If you are not participating in the competition, you can ignore this warning!"
+                 "\n\n"
+                 "EEG 2025 Competition Data Notice:\n"
+                 "---------------------------------\n"
+                 " You are loading the dataset that is used in the EEG 2025 Competition:\n"
+                 "IMPORTANT: The data accessed via `EEGDashDataset` is NOT identical to what you get from `EEGChallengeDataset` object directly.\n"
+                 "and it is not what you will use for the competition. Downsampling and filtering were applied to the data"
+                 "to allow more people to participate.\n"
+                 "\n"
+                 "If you are participating in the competition, always use `EEGChallengeDataset` to ensure consistency with the challenge data.\n"
+                 "\n",
+                 UserWarning,
+                 module="eegdash",
+             )
+         if records is not None:
+             self.records = records
+             datasets = [
+                 EEGDashBaseDataset(
+                     record,
+                     self.cache_dir,
+                     self.s3_bucket,
+                     **base_dataset_kwargs,
+                 )
+                 for record in self.records
+             ]
+         elif not download: # only assume local data is complete if not downloading
+             if not self.data_dir.exists():
+                 raise ValueError(
+                     f"Offline mode is enabled, but local data_dir {self.data_dir} does not exist."
+                 )
+             records = self._find_local_bids_records(self.data_dir, self.query)
+             # Try to enrich from local participants.tsv to restore requested fields
+             try:
+                 bids_ds = EEGBIDSDataset(
+                     data_dir=str(self.data_dir), dataset=self.query["dataset"]
+                 ) # type: ignore[index]
+             except Exception:
+                 bids_ds = None
+
+             datasets = []
+             for record in records:
+                 # Start with entity values from filename
+                 desc: dict[str, Any] = {
+                     k: record.get(k)
+                     for k in ("subject", "session", "run", "task")
+                     if record.get(k) is not None
+                 }
+
+                 if bids_ds is not None:
+                     try:
+                         rel_from_dataset = Path(record["bidspath"]).relative_to(
+                             record["dataset"]
+                         ) # type: ignore[index]
+                         local_file = (self.data_dir / rel_from_dataset).as_posix()
+                         part_row = bids_ds.subject_participant_tsv(local_file)
+                         desc = merge_participants_fields(
+                             description=desc,
+                             participants_row=part_row
+                             if isinstance(part_row, dict)
+                             else None,
+                             description_fields=description_fields,
+                         )
+                     except Exception:
+                         pass

-         try:
-             if records is not None:
-                 self.records = records
-                 datasets = [
+                 datasets.append(
                      EEGDashBaseDataset(
-                         record,
-                         self.cache_dir,
-                         self.s3_bucket,
-                         **base_dataset_kwargs,
-                     )
-                     for record in self.records
-                 ]
-             elif data_dir:
-                 # This path loads from a local directory and is not affected by DB query logic
-                 if isinstance(data_dir, (str, Path)):
-                     datasets = self.load_bids_dataset(
-                         dataset=dataset
-                         if isinstance(dataset, str)
-                         else (dataset[0] if dataset else None),
-                         data_dir=data_dir,
-                         description_fields=description_fields,
-                         s3_bucket=s3_bucket,
+                         record=record,
+                         cache_dir=self.cache_dir,
+                         s3_bucket=self.s3_bucket,
+                         description=desc,
                          **base_dataset_kwargs,
                      )
-                 else:
-                     assert dataset is not None, (
-                         "dataset must be provided when passing multiple data_dir"
-                     )
-                     assert len(data_dir) == len(dataset), (
-                         "Number of datasets and directories must match"
-                     )
-                     datasets = []
-                     for i, _ in enumerate(data_dir):
-                         datasets.extend(
-                             self.load_bids_dataset(
-                                 dataset=dataset[i],
-                                 data_dir=data_dir[i],
-                                 description_fields=description_fields,
-                                 s3_bucket=s3_bucket,
-                                 **base_dataset_kwargs,
-                             )
-                         )
-             elif query is not None or query_kwargs:
-                 # This is the DB query path that we are improving
-                 datasets = self.find_datasets(
-                     query=query,
-                     description_fields=description_fields,
-                     query_kwargs=query_kwargs,
-                     base_dataset_kwargs=base_dataset_kwargs,
-                 )
-                 # We only need filesystem if we need to access S3
-                 self.filesystem = S3FileSystem(
-                     anon=True, client_kwargs={"region_name": "us-east-2"}
                  )
-             else:
-                 raise ValueError(
-                     "You must provide either 'records', a 'data_dir', or a query/keyword arguments for filtering."
-                 )
-         finally:
-             if _owns_client and self.eeg_dash is not None:
-                 self.eeg_dash.close()
+         elif self.query:
+             # This is the DB query path that we are improving
+             datasets = self._find_datasets(
+                 query=build_query_from_kwargs(**self.query),
+                 description_fields=description_fields,
+                 base_dataset_kwargs=base_dataset_kwargs,
+             )
+             # We only need filesystem if we need to access S3
+             self.filesystem = S3FileSystem(
+                 anon=True, client_kwargs={"region_name": "us-east-2"}
+             )
+         else:
+             raise ValueError(
+                 "You must provide either 'records', a 'data_dir', or a query/keyword arguments for filtering."
+             )

          super().__init__(datasets)

-     def find_key_in_nested_dict(self, data: Any, target_key: str) -> Any:
-         """Helper to recursively search for a key in a nested dictionary structure; returns
-         the value associated with the first occurrence of the key, or None if not found.
+     def _find_local_bids_records(
+         self, dataset_root: Path, filters: dict[str, Any]
+     ) -> list[dict]:
+         """Discover local BIDS EEG files and build minimal records.
+
+         This helper enumerates EEG recordings under ``dataset_root`` via
+         ``mne_bids.find_matching_paths`` and applies entity filters to produce a
+         list of records suitable for ``EEGDashBaseDataset``. No network access
+         is performed and files are not read.
+
+         Parameters
+         ----------
+         dataset_root : Path
+             Local dataset directory. May be the plain dataset folder (e.g.,
+             ``ds005509``) or a suffixed cache variant (e.g.,
+             ``ds005509-bdf-mini``).
+         filters : dict of {str, Any}
+             Query filters. Must include ``'dataset'`` with the dataset id (without
+             local suffixes). May include BIDS entities ``'subject'``,
+             ``'session'``, ``'task'``, and ``'run'``. Each value can be a scalar
+             or a sequence of scalars.
+
+         Returns
+         -------
+         records : list of dict
+             One record per matched EEG file with at least:
+
+             - ``'data_name'``
+             - ``'dataset'`` (dataset id, without suffixes)
+             - ``'bidspath'`` (normalized to start with the dataset id)
+             - ``'subject'``, ``'session'``, ``'task'``, ``'run'`` (may be None)
+             - ``'bidsdependencies'`` (empty list)
+             - ``'modality'`` (``"eeg"``)
+             - ``'sampling_frequency'``, ``'nchans'``, ``'ntimes'`` (minimal
+               defaults for offline usage)
+
+         Notes
+         -----
+         - Matching uses ``datatypes=['eeg']`` and ``suffixes=['eeg']``.
+         - ``bidspath`` is constructed as
+           ``<dataset_id> / <relative_path_from_dataset_root>`` to ensure the
+           first path component is the dataset id (without local cache suffixes).
+         - Minimal defaults are set for ``sampling_frequency``, ``nchans``, and
+           ``ntimes`` to satisfy dataset length requirements offline.
+
          """
+         dataset_id = filters["dataset"]
+         arg_map = {
+             "subjects": "subject",
+             "sessions": "session",
+             "tasks": "task",
+             "runs": "run",
+         }
+         matching_args: dict[str, list[str]] = {}
+         for finder_key, entity_key in arg_map.items():
+             entity_val = filters.get(entity_key)
+             if entity_val is None:
+                 continue
+             if isinstance(entity_val, (list, tuple, set)):
+                 entity_vals = list(entity_val)
+                 if not entity_vals:
+                     continue
+                 matching_args[finder_key] = entity_vals
+             else:
+                 matching_args[finder_key] = [entity_val]
+
+         matched_paths = find_matching_paths(
+             root=str(dataset_root),
+             datatypes=["eeg"],
+             suffixes=["eeg"],
+             ignore_json=True,
+             **matching_args,
+         )
+         records_out: list[dict] = []
+
+         for bids_path in matched_paths:
+             # Build bidspath as dataset_id / relative_path_from_dataset_root (POSIX)
+             rel_from_root = (
+                 Path(bids_path.fpath)
+                 .resolve()
+                 .relative_to(Path(bids_path.root).resolve())
+             )
+             bidspath = f"{dataset_id}/{rel_from_root.as_posix()}"
+
+             rec = {
+                 "data_name": f"{dataset_id}_{Path(bids_path.fpath).name}",
+                 "dataset": dataset_id,
+                 "bidspath": bidspath,
+                 "subject": (bids_path.subject or None),
+                 "session": (bids_path.session or None),
+                 "task": (bids_path.task or None),
+                 "run": (bids_path.run or None),
+                 # minimal fields to satisfy BaseDataset from eegdash
+                 "bidsdependencies": [], # not needed to just run.
+                 "modality": "eeg",
+                 # minimal numeric defaults for offline length calculation
+                 "sampling_frequency": None,
+                 "nchans": None,
+                 "ntimes": None,
+             }
+             records_out.append(rec)
+
+         return records_out
+
+     def _find_key_in_nested_dict(self, data: Any, target_key: str) -> Any:
+         """Recursively search for target_key in nested dicts/lists with normalized matching.
+
+         This makes lookups tolerant to naming differences like "p-factor" vs "p_factor".
+         Returns the first match or None.
+         """
+         norm_target = normalize_key(target_key)
          if isinstance(data, dict):
-             if target_key in data:
-                 return data[target_key]
-             for value in data.values():
-                 result = self.find_key_in_nested_dict(value, target_key)
-                 if result is not None:
-                     return result
+             for k, v in data.items():
+                 if normalize_key(k) == norm_target:
+                     return v
+                 res = self._find_key_in_nested_dict(v, target_key)
+                 if res is not None:
+                     return res
+         elif isinstance(data, list):
+             for item in data:
+                 res = self._find_key_in_nested_dict(item, target_key)
+                 if res is not None:
+                     return res
          return None

-     def find_datasets(
+     def _find_datasets(
          self,
          query: dict[str, Any] | None,
          description_fields: list[str],
-         query_kwargs: dict,
          base_dataset_kwargs: dict,
      ) -> list[EEGDashBaseDataset]:
          """Helper method to find datasets in the MongoDB collection that satisfy the
@@ -887,87 +1005,30 @@ class EEGDashDataset(BaseConcatDataset):

          """
          datasets: list[EEGDashBaseDataset] = []
-
-         # Build records using either a raw query OR keyword filters, but not both.
-         # Note: callers may accidentally pass an empty dict for `query` along with
-         # kwargs. In that case, treat it as if no query was provided and rely on kwargs.
-         # Always delegate merging of raw query + kwargs to EEGDash.find
-         self.records = self.eeg_dash.find(query, **query_kwargs)
+         self.records = self.eeg_dash_instance.find(query)

          for record in self.records:
-             description = {}
+             description: dict[str, Any] = {}
+             # Requested fields first (normalized matching)
              for field in description_fields:
-                 value = self.find_key_in_nested_dict(record, field)
+                 value = self._find_key_in_nested_dict(record, field)
                  if value is not None:
                      description[field] = value
+             # Merge all participants.tsv columns generically
+             part = self._find_key_in_nested_dict(record, "participant_tsv")
+             if isinstance(part, dict):
+                 description = merge_participants_fields(
+                     description=description,
+                     participants_row=part,
+                     description_fields=description_fields,
+                 )
              datasets.append(
                  EEGDashBaseDataset(
                      record,
-                     self.cache_dir,
-                     self.s3_bucket,
+                     cache_dir=self.cache_dir,
+                     s3_bucket=self.s3_bucket,
                      description=description,
                      **base_dataset_kwargs,
                  )
              )
          return datasets
-
-     def load_bids_dataset(
-         self,
-         dataset: str,
-         data_dir: str | Path,
-         description_fields: list[str],
-         s3_bucket: str | None = None,
-         **kwargs,
-     ):
-         """Helper method to load a single local BIDS dataset and return it as a list of
-         EEGDashBaseDatasets (one for each recording in the dataset).
-
-         Parameters
-         ----------
-         dataset : str
-             A name for the dataset to be loaded (e.g., "ds002718").
-         data_dir : str
-             The path to the local BIDS dataset directory.
-         description_fields : list[str]
-             A list of fields to be extracted from the dataset records
-             and included in the returned dataset description(s).
-
-         """
-         bids_dataset = EEGBIDSDataset(
-             data_dir=data_dir,
-             dataset=dataset,
-         )
-         datasets = Parallel(n_jobs=-1, prefer="threads", verbose=1)(
-             delayed(self.get_base_dataset_from_bids_file)(
-                 bids_dataset=bids_dataset,
-                 bids_file=bids_file,
-                 s3_bucket=s3_bucket,
-                 description_fields=description_fields,
-                 **kwargs,
-             )
-             for bids_file in bids_dataset.get_files()
-         )
-         return datasets
-
-     def get_base_dataset_from_bids_file(
-         self,
-         bids_dataset: "EEGBIDSDataset",
-         bids_file: str,
-         s3_bucket: str | None,
-         description_fields: list[str],
-         **kwargs,
-     ) -> "EEGDashBaseDataset":
-         """Instantiate a single EEGDashBaseDataset given a local BIDS file (metadata only)."""
-         record = self.eeg_dash.load_eeg_attrs_from_bids_file(bids_dataset, bids_file)
-         description = {}
-         for field in description_fields:
-             value = self.find_key_in_nested_dict(record, field)
-             if value is not None:
-                 description[field] = value
-         return EEGDashBaseDataset(
-             record,
-             self.cache_dir,
-             s3_bucket,
-             description=description,
-             **kwargs,
-         )
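
Note: the helpers removed above (`load_bids_dataset`, `get_base_dataset_from_bids_file`) are superseded by the offline path added earlier in this diff (`download=False` plus `_find_local_bids_records`). A minimal sketch of the replacement workflow, assuming the dataset was already cached locally:

    from eegdash.api import EEGDashDataset

    # Expects data under ./eegdash_cache/ds002718; no DB or S3 access is attempted.
    offline = EEGDashDataset(
        cache_dir="./eegdash_cache",
        dataset="ds002718",
        download=False,
    )
    for rec in offline.datasets[:3]:
        print(rec.description)  # entities recovered from filenames/participants.tsv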