eegdash 0.1.0__py3-none-any.whl → 0.2.1.dev178237806__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


eegdash/api.py ADDED
@@ -0,0 +1,717 @@
+ import logging
+ import os
+ import tempfile
+ from pathlib import Path
+ from typing import Any, Mapping
+
+ import mne
+ import numpy as np
+ import xarray as xr
+ from dotenv import load_dotenv
+ from joblib import Parallel, delayed
+ from pymongo import InsertOne, UpdateOne
+ from s3fs import S3FileSystem
+
+ from braindecode.datasets import BaseConcatDataset
+
+ from .data_config import config as data_config
+ from .data_utils import EEGBIDSDataset, EEGDashBaseDataset
+ from .mongodb import MongoConnectionManager
+
+ logger = logging.getLogger("eegdash")
+
+
+ class EEGDash:
+     """A high-level interface to the EEGDash database.
+
+     This class is primarily used to interact with the metadata records stored in the
+     EEGDash database (or a private instance of it), allowing users to find, add, and
+     update EEG data records.
+
+     While this class provides basic support for loading EEG data, please see the
+     EEGDashDataset class for a more complete way to retrieve and work with full
+     datasets.
+     """
+
+     AWS_BUCKET = "s3://openneuro.org"
+
+     def __init__(self, *, is_public: bool = True, is_staging: bool = False) -> None:
+         """Create a new instance of the EEGDash database client.
+
+         Parameters
+         ----------
+         is_public : bool
+             Whether to connect to the public MongoDB database; if False, connect to a
+             private database instance as per the DB_CONNECTION_STRING env variable
+             (or .env file entry).
+         is_staging : bool
+             If True, use the staging MongoDB database ("eegdashstaging"); otherwise
+             use the production database ("eegdash").
+
+         Example
+         -------
+         >>> eegdash = EEGDash()
+
+         """
+         self.config = data_config
+         self.is_public = is_public
+         self.is_staging = is_staging
+
+         if self.is_public:
+             DB_CONNECTION_STRING = mne.utils.get_config("EEGDASH_DB_URI")
+         else:
+             load_dotenv()
+             DB_CONNECTION_STRING = os.getenv("DB_CONNECTION_STRING")
+
+         # Use the singleton to get the MongoDB client, database, and collection
+         self.__client, self.__db, self.__collection = MongoConnectionManager.get_client(
+             DB_CONNECTION_STRING, is_staging
+         )
+
+         self.filesystem = S3FileSystem(
+             anon=True, client_kwargs={"region_name": "us-east-2"}
+         )
+
+     def find(self, query: dict[str, Any], *args, **kwargs) -> list[Mapping[str, Any]]:
+         """Find records in the MongoDB collection that satisfy the given query.
+
+         Parameters
+         ----------
+         query : dict
+             A dictionary that specifies the query to be executed; this is a reference
+             document that is used to match records in the MongoDB collection.
+         args :
+             Additional positional arguments for the MongoDB find() method; see
+             https://pymongo.readthedocs.io/en/stable/api/pymongo/collection.html#pymongo.collection.Collection.find
+         kwargs :
+             Additional keyword arguments for the MongoDB find() method.
+
+         Returns
+         -------
+         list :
+             A list of DB records (string-keyed dictionaries) that match the query.
+
+         Example
+         -------
+         >>> eegdash = EEGDash()
+         >>> eegdash.find({"dataset": "ds002718", "subject": "012"})
+
+         """
+         results = self.__collection.find(query, *args, **kwargs)
+
+         return list(results)
+
+     def exist(self, query: dict[str, Any]) -> bool:
+         """Return True if at least one record matches the query, else False.
+
+         This is a lightweight existence check that uses MongoDB's ``find_one``
+         instead of fetching all matching documents (which would be wasteful in
+         both time and memory for broad queries). Only a restricted set of
+         fields is accepted to avoid accidental full scans caused by malformed
+         or unsupported keys.
+
+         Parameters
+         ----------
+         query : dict
+             Mapping of allowed field(s) to value(s). Allowed keys: ``data_name``
+             and ``dataset``. The query must not be empty.
+
+         Returns
+         -------
+         bool
+             True if at least one matching record exists; False otherwise.
+
+         Raises
+         ------
+         TypeError
+             If ``query`` is not a dict.
+         ValueError
+             If ``query`` is empty or contains unsupported field names.
+
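+         Example
+         -------
+         A minimal sketch (assuming the dataset "ds002718" is in the database):
+
+         >>> eegdash = EEGDash()
+         >>> eegdash.exist({"dataset": "ds002718"})
+         True
+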
+         """
+         if not isinstance(query, dict):
+             raise TypeError("query must be a dict")
+         if not query:
+             raise ValueError("query cannot be empty")
+
+         accepted_query_fields = {"data_name", "dataset"}
+         unknown = set(query.keys()) - accepted_query_fields
+         if unknown:
+             raise ValueError(
+                 f"Unsupported query field(s): {', '.join(sorted(unknown))}. "
+                 f"Allowed: {sorted(accepted_query_fields)}"
+             )
+
+         doc = self.__collection.find_one(query, projection={"_id": 1})
+         return doc is not None
+
+     def _validate_input(self, record: dict[str, Any]) -> dict[str, Any]:
+         """Validate the input record against the expected schema (internal).
+
+         Parameters
+         ----------
+         record : dict
+             A dictionary representing the EEG data record to be validated.
+
+         Returns
+         -------
+         dict :
+             The record itself on success; raises a ValueError if the record is
+             invalid.
+
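+         Example
+         -------
+         A minimal sketch with hypothetical values:
+
+         >>> eegdash._validate_input({"data_name": "ds002718_sub-012_eeg.set", "nchans": 64})
+         {'data_name': 'ds002718_sub-012_eeg.set', 'nchans': 64}
+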
+         """
+         input_types = {
+             "data_name": str,
+             "dataset": str,
+             "bidspath": str,
+             "subject": str,
+             "task": str,
+             "session": str,
+             "run": str,
+             "sampling_frequency": float,
+             "modality": str,
+             "nchans": int,
+             "ntimes": int,
+             "channel_types": list,
+             "channel_names": list,
+         }
+         if "data_name" not in record:
+             raise ValueError("Missing key: data_name")
+         # check that each key is expected and that its value has the correct type
+         for key, value in record.items():
+             if key not in input_types:
+                 raise ValueError(f"Invalid input: {key}")
+             if not isinstance(value, input_types[key]):
+                 raise ValueError(
+                     f"Invalid type for '{key}': expected {input_types[key].__name__}"
+                 )
+
+         return record
+
+     def load_eeg_data_from_s3(self, s3path: str) -> xr.DataArray:
+         """Load an EEGLAB .set file from an AWS S3 URI and return it as an xarray DataArray.
+
+         Parameters
+         ----------
+         s3path : str
+             An S3 URI (should start with "s3://") for the file in question.
+
+         Returns
+         -------
+         xr.DataArray
+             A DataArray containing the EEG data, with dimensions "channel" and "time".
+
+         Example
+         -------
+         >>> eegdash = EEGDash()
+         >>> mypath = "s3://openneuro.org/path/to/your/eeg_data.set"
+         >>> mydata = eegdash.load_eeg_data_from_s3(mypath)
+
+         """
+         with tempfile.NamedTemporaryFile(delete=False, suffix=".set") as tmp:
+             with self.filesystem.open(s3path) as s3_file:
+                 tmp.write(s3_file.read())
+             tmp_path = tmp.name
+         eeg_data = self.load_eeg_data_from_bids_file(tmp_path)
+         os.unlink(tmp_path)
+         return eeg_data
+
+     def load_eeg_data_from_bids_file(self, bids_file: str) -> xr.DataArray:
+         """Load EEG data from a local file and return it as an xarray DataArray.
+
+         Parameters
+         ----------
+         bids_file : str
+             Path to the file on the local filesystem.
+
+         Notes
+         -----
+         Currently, only non-epoched .set files are supported.
+
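+         Example
+         -------
+         A minimal sketch (the local file path is hypothetical):
+
+         >>> eegdash = EEGDash()
+         >>> data = eegdash.load_eeg_data_from_bids_file("sub-012_task-rest_eeg.set")
+         >>> data.dims
+         ('channel', 'time')
+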
+         """
+         raw_object = mne.io.read_raw(bids_file)
+         eeg_data = raw_object.get_data()
+
+         fs = raw_object.info["sfreq"]
+         time_steps = np.arange(eeg_data.shape[1]) / fs  # sample times in seconds
+
+         channel_names = raw_object.ch_names
+
+         eeg_xarray = xr.DataArray(
+             data=eeg_data,
+             dims=["channel", "time"],
+             coords={"time": time_steps, "channel": channel_names},
+         )
+         return eeg_xarray
+
+     def get_raw_extensions(
+         self, bids_file: str, bids_dataset: EEGBIDSDataset
+     ) -> list[str]:
+         """Find paths to additional "sidecar" files that may be associated with a
+         given main data file in a BIDS dataset; paths are returned relative to the
+         parent dataset path.
+
+         For example, if the input file is a .set file, this will return the relative
+         path to a corresponding .fdt file (if any).
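+
+         A minimal sketch with hypothetical paths (the returned list would contain
+         the .set file and its .fdt sidecar, if present):
+
+         >>> eegdash.get_raw_extensions(
+         ...     "/data/ds002718/sub-012/eeg/sub-012_task-rest_eeg.set", bids_dataset
+         ... )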
+         """
+         bids_file = Path(bids_file)
+         extensions = {
+             ".set": [".set", ".fdt"],  # eeglab
+             ".edf": [".edf"],  # european data format
+             ".vhdr": [".eeg", ".vhdr", ".vmrk", ".dat", ".raw"],  # brainvision
+             ".bdf": [".bdf"],  # biosemi
+         }
+         return [
+             str(bids_dataset.get_relative_bidspath(bids_file.with_suffix(suffix)))
+             for suffix in extensions.get(bids_file.suffix, [])
+             if bids_file.with_suffix(suffix).exists()
+         ]
+
+     def load_eeg_attrs_from_bids_file(
+         self, bids_dataset: EEGBIDSDataset, bids_file: str
+     ) -> dict[str, Any]:
+         """Build the metadata record for a given BIDS file (single recording) in a
+         BIDS dataset.
+
+         Attributes are at least the ones defined in the data_config attributes (set
+         to None if missing), but are typically a superset, and include, among others,
+         the paths to relevant metadata files needed to load and interpret the file
+         in question.
+
+         Parameters
+         ----------
+         bids_dataset : EEGBIDSDataset
+             The BIDS dataset object containing the file.
+         bids_file : str
+             The path to the BIDS file within the dataset.
+
+         Returns
+         -------
+         dict :
+             A dictionary representing the metadata record for the given file. This is
+             the same format as the records stored in the database.
+
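+         Example
+         -------
+         A minimal sketch (assuming a local copy of the dataset at a hypothetical path):
+
+         >>> bids_dataset = EEGBIDSDataset(data_dir="path/to/ds002718", dataset="ds002718")
+         >>> record = eegdash.load_eeg_attrs_from_bids_file(bids_dataset, bids_dataset.get_files()[0])
+         >>> record["dataset"]
+         'ds002718'
+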
+         """
+         if bids_file not in bids_dataset.files:
+             raise ValueError(f"{bids_file} not in {bids_dataset.dataset}")
+
+         # Initialize attrs with None values for all expected fields
+         attrs = {field: None for field in self.config["attributes"].keys()}
+
+         file = Path(bids_file).name
+         dsnumber = bids_dataset.dataset
+         # Extract the OpenNeuro path by finding the first occurrence of the dataset
+         # name in the filename and removing everything before it
+         openneuro_path = dsnumber + bids_file.split(dsnumber)[1]
+
+         # Update with actual values where available
+         try:
+             participants_tsv = bids_dataset.subject_participant_tsv(bids_file)
+         except Exception as e:
+             logger.error("Error getting participants_tsv: %s", str(e))
+             participants_tsv = None
+
+         try:
+             eeg_json = bids_dataset.eeg_json(bids_file)
+         except Exception as e:
+             logger.error("Error getting eeg_json: %s", str(e))
+             eeg_json = None
+
+         bids_dependencies_files = self.config["bids_dependencies_files"]
+         bidsdependencies = []
+         for extension in bids_dependencies_files:
+             try:
+                 dep_path = bids_dataset.get_bids_metadata_files(bids_file, extension)
+                 dep_path = [
+                     str(bids_dataset.get_relative_bidspath(dep)) for dep in dep_path
+                 ]
+                 bidsdependencies.extend(dep_path)
+             except Exception:
+                 pass
+
+         bidsdependencies.extend(self.get_raw_extensions(bids_file, bids_dataset))
+
+         # Define field extraction functions with error handling
+         field_extractors = {
+             "data_name": lambda: f"{bids_dataset.dataset}_{file}",
+             "dataset": lambda: bids_dataset.dataset,
+             "bidspath": lambda: openneuro_path,
+             "subject": lambda: bids_dataset.get_bids_file_attribute(
+                 "subject", bids_file
+             ),
+             "task": lambda: bids_dataset.get_bids_file_attribute("task", bids_file),
+             "session": lambda: bids_dataset.get_bids_file_attribute(
+                 "session", bids_file
+             ),
+             "run": lambda: bids_dataset.get_bids_file_attribute("run", bids_file),
+             "modality": lambda: bids_dataset.get_bids_file_attribute(
+                 "modality", bids_file
+             ),
+             "sampling_frequency": lambda: bids_dataset.get_bids_file_attribute(
+                 "sfreq", bids_file
+             ),
+             "nchans": lambda: bids_dataset.get_bids_file_attribute("nchans", bids_file),
+             "ntimes": lambda: bids_dataset.get_bids_file_attribute("ntimes", bids_file),
+             "participant_tsv": lambda: participants_tsv,
+             "eeg_json": lambda: eeg_json,
+             "bidsdependencies": lambda: bidsdependencies,
+         }
+
+         # Dynamically populate attrs with error handling
+         for field, extractor in field_extractors.items():
+             try:
+                 attrs[field] = extractor()
+             except Exception as e:
+                 logger.error("Error extracting %s: %s", field, str(e))
+                 attrs[field] = None
+
+         return attrs
+
+     def add_bids_dataset(
+         self, dataset: str, data_dir: str, overwrite: bool = True
+     ) -> None:
+         """Traverse the BIDS dataset at data_dir and add its records to the MongoDB
+         database, under the given dataset name.
+
+         Parameters
+         ----------
+         dataset : str
+             The name of the dataset to be added (e.g., "ds002718").
+         data_dir : str
+             The path to the BIDS dataset directory.
+         overwrite : bool
+             Whether to overwrite/update existing records in the database.
+
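+         Example
+         -------
+         A minimal sketch (requires a private database instance; the local dataset
+         path is hypothetical):
+
+         >>> eegdash = EEGDash(is_public=False)
+         >>> eegdash.add_bids_dataset("ds002718", "path/to/ds002718")
+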
+         """
+         if self.is_public:
+             raise ValueError("This operation is not allowed for public users")
+
+         if not overwrite and self.exist({"dataset": dataset}):
+             logger.info("Dataset %s already exists in the database", dataset)
+             return
+         try:
+             bids_dataset = EEGBIDSDataset(
+                 data_dir=data_dir,
+                 dataset=dataset,
+             )
+         except Exception as e:
+             logger.error("Error creating bids dataset %s: %s", dataset, str(e))
+             raise e
+         requests = []
+         for bids_file in bids_dataset.get_files():
+             try:
+                 data_id = f"{dataset}_{Path(bids_file).name}"
+
+                 if self.exist({"data_name": data_id}):
+                     if overwrite:
+                         eeg_attrs = self.load_eeg_attrs_from_bids_file(
+                             bids_dataset, bids_file
+                         )
+                         requests.append(self.update_request(eeg_attrs))
+                 else:
+                     eeg_attrs = self.load_eeg_attrs_from_bids_file(
+                         bids_dataset, bids_file
+                     )
+                     requests.append(self.add_request(eeg_attrs))
+             except Exception as e:
+                 logger.error("Error adding record %s", bids_file)
+                 logger.error(str(e))
+
+         logger.info("Number of requests: %s", len(requests))
+
+         if requests:
+             result = self.__collection.bulk_write(requests, ordered=False)
+             logger.info("Inserted: %s", result.inserted_count)
+             logger.info("Modified: %s", result.modified_count)
+             logger.info("Deleted: %s", result.deleted_count)
+             logger.info("Upserted: %s", result.upserted_count)
+             logger.info("Errors: %s", result.bulk_api_result.get("writeErrors", []))
+
+     def get(self, query: dict[str, Any]) -> list[xr.DataArray]:
+         """Retrieve a list of EEG data arrays that match the given query. See also
+         the `find()` method for details on the query format.
+
+         Parameters
+         ----------
+         query : dict
+             A dictionary that specifies the query to be executed; this is a reference
+             document that is used to match records in the MongoDB collection.
+
+         Returns
+         -------
+         list :
+             A list of xarray DataArray objects containing the EEG data for each
+             matching record.
+
+         Notes
+         -----
+         Retrieval is done in parallel, and the downloaded data are not cached locally.
+
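+         Example
+         -------
+         >>> eegdash = EEGDash()
+         >>> arrays = eegdash.get({"dataset": "ds002718", "subject": "012"})
+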
+         """
+         sessions = self.find(query)
+         results = []
+         if sessions:
+             logger.info("Found %s records", len(sessions))
+             results = Parallel(
+                 n_jobs=-1 if len(sessions) > 1 else 1, prefer="threads", verbose=1
+             )(
+                 delayed(self.load_eeg_data_from_s3)(self.get_s3path(session))
+                 for session in sessions
+             )
+         return results
+
+     def add_request(self, record: dict):
+         """Create a MongoDB insertion request for a record (internal helper)."""
+         return InsertOne(record)
+
+     def add(self, record: dict):
+         """Add a single record to the MongoDB collection."""
+         try:
+             self.__collection.insert_one(record)
+         except ValueError as e:
+             logger.error("Validation error for record: %s", record["data_name"])
+             logger.error(e)
+         except Exception:
+             logger.error("Error adding record: %s", record["data_name"])
+
+     def update_request(self, record: dict):
+         """Create a MongoDB update request for a record (internal helper)."""
+         return UpdateOne({"data_name": record["data_name"]}, {"$set": record})
+
+     def update(self, record: dict):
+         """Update a single record in the MongoDB collection."""
+         try:
+             self.__collection.update_one(
+                 {"data_name": record["data_name"]}, {"$set": record}
+             )
+         except Exception:  # log, but do not propagate
+             logger.error("Error updating record: %s", record["data_name"])
+
+     def remove_field(self, record, field):
+         """Remove a specific field from a record in the MongoDB collection."""
+         self.__collection.update_one(
+             {"data_name": record["data_name"]}, {"$unset": {field: 1}}
+         )
+
+     def remove_field_from_db(self, field):
+         """Remove all occurrences of a specific field from all records in the MongoDB
+         collection. WARNING: this operation is destructive and should be used with
+         caution.
+         """
+         self.__collection.update_many({}, {"$unset": {field: 1}})
+
+     @property
+     def collection(self):
+         """Return the MongoDB collection object."""
+         return self.__collection
+
+     def close(self):
+         """Close the MongoDB client connection.
+
+         Note: Since MongoDB clients are now managed by a singleton, this method no
+         longer closes connections. Use the close_all_connections() class method to
+         close all connections if needed.
+         """
+         # Individual instances no longer close the shared client
+         pass
+
+     @classmethod
+     def close_all_connections(cls):
+         """Close all MongoDB client connections managed by the singleton."""
+         MongoConnectionManager.close_all()
+
+     def __del__(self):
+         """Destructor; a no-op, since the shared client is managed by the singleton."""
+         pass
+
+
+ class EEGDashDataset(BaseConcatDataset):
+     def __init__(
+         self,
+         query: dict | None = None,
+         data_dir: str | list | None = None,
+         dataset: str | list | None = None,
+         description_fields: list[str] = [
+             "subject",
+             "session",
+             "run",
+             "task",
+             "age",
+             "gender",
+             "sex",
+         ],
+         cache_dir: str = ".eegdash_cache",
+         s3_bucket: str | None = None,
+         **kwargs,
+     ):
+         """Create a new EEGDashDataset from a given query or local BIDS dataset
+         directory and dataset name. An EEGDashDataset is a pooled collection of
+         EEGDashBaseDataset instances (individual recordings) and is a subclass of
+         braindecode's BaseConcatDataset.
+
+         Parameters
+         ----------
+         query : dict | None
+             Optionally a dictionary that specifies the query to be executed; see
+             EEGDash.find() for details on the query format.
+         data_dir : str | list[str] | None
+             Optionally a string or a list of strings specifying one or more local
+             BIDS dataset directories from which to load the EEG data files. Exactly
+             one of query or data_dir must be provided.
+         dataset : str | list[str] | None
+             If data_dir is given, a name or list of names for the dataset(s) to be
+             loaded.
+         description_fields : list[str]
+             A list of fields to be extracted from the dataset records
+             and included in the returned data description(s). Examples are typical
+             subject metadata fields such as "subject", "session", "run", "task", etc.;
+             see also data_config.description_fields for the default set of fields.
+         cache_dir : str
+             A directory where the dataset will be cached locally.
+         s3_bucket : str | None
+             An optional S3 bucket URI (e.g., "s3://mybucket") to use instead of the
+             default OpenNeuro bucket for loading data files.
+         kwargs : dict
+             Additional keyword arguments to be passed to the EEGDashBaseDataset
+             constructor.
+
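+         Examples
+         --------
+         A minimal sketch; the local directory path is hypothetical:
+
+         >>> ds_remote = EEGDashDataset(query={"dataset": "ds002718", "subject": "012"})
+         >>> ds_local = EEGDashDataset(data_dir="path/to/ds002718", dataset="ds002718")
+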
+         """
+         self.cache_dir = cache_dir
+         self.s3_bucket = s3_bucket
+         if query:
+             datasets = self.find_datasets(query, description_fields, **kwargs)
+         elif data_dir:
+             if isinstance(data_dir, str):
+                 datasets = self.load_bids_dataset(
+                     dataset, data_dir, description_fields, s3_bucket
+                 )
+             else:
+                 assert len(data_dir) == len(dataset), (
+                     "Number of datasets and their directories must match"
+                 )
+                 datasets = []
+                 for i, _ in enumerate(data_dir):
+                     datasets.extend(
+                         self.load_bids_dataset(
+                             dataset[i], data_dir[i], description_fields, s3_bucket
+                         )
+                     )
+
+         super().__init__(datasets)
+
+     def find_key_in_nested_dict(self, data: Any, target_key: str) -> Any:
+         """Recursively search for a key in a nested dictionary structure; returns
+         the value associated with the first occurrence of the key, or None if not
+         found.
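+
+         Example (assuming ``ds`` is an EEGDashDataset instance):
+
+         >>> ds.find_key_in_nested_dict({"meta": {"subject": "012"}}, "subject")
+         '012'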
+         """
+         if isinstance(data, dict):
+             if target_key in data:
+                 return data[target_key]
+             for value in data.values():
+                 result = self.find_key_in_nested_dict(value, target_key)
+                 if result is not None:
+                     return result
+         return None
+
+     def find_datasets(
+         self, query: dict[str, Any], description_fields: list[str], **kwargs
+     ) -> list[EEGDashBaseDataset]:
+         """Find datasets in the MongoDB collection that satisfy the given query and
+         return them as a list of EEGDashBaseDataset objects.
+
+         Parameters
+         ----------
+         query : dict
+             The query object, as in EEGDash.find().
+         description_fields : list[str]
+             A list of fields to be extracted from the dataset records and included in
+             the returned dataset description(s).
+         kwargs : dict
+             Additional keyword arguments to be passed to the EEGDashBaseDataset
+             constructor.
+
+         Returns
+         -------
+         list :
+             A list of EEGDashBaseDataset objects that match the query.
+
+         """
+         eeg_dash_instance = EEGDash()
+         try:
+             datasets = []
+             for record in eeg_dash_instance.find(query):
+                 description = {}
+                 for field in description_fields:
+                     value = self.find_key_in_nested_dict(record, field)
+                     if value is not None:
+                         description[field] = value
+                 datasets.append(
+                     EEGDashBaseDataset(
+                         record,
+                         self.cache_dir,
+                         self.s3_bucket,
+                         description=description,
+                         **kwargs,
+                     )
+                 )
+             return datasets
+         finally:
+             eeg_dash_instance.close()
+
+     def load_bids_dataset(
+         self,
+         dataset,
+         data_dir,
+         description_fields: list[str],
+         s3_bucket: str | None = None,
+         **kwargs,
+     ):
+         """Load a single local BIDS dataset and return it as a list of
+         EEGDashBaseDatasets (one for each recording in the dataset).
+
+         Parameters
+         ----------
+         dataset : str
+             A name for the dataset to be loaded (e.g., "ds002718").
+         data_dir : str
+             The path to the local BIDS dataset directory.
+         description_fields : list[str]
+             A list of fields to be extracted from the dataset records
+             and included in the returned dataset description(s).
+
+         """
+         bids_dataset = EEGBIDSDataset(
+             data_dir=data_dir,
+             dataset=dataset,
+         )
+         eeg_dash_instance = EEGDash()
+         try:
+             datasets = Parallel(n_jobs=-1, prefer="threads", verbose=1)(
+                 delayed(self.get_base_dataset_from_bids_file)(
+                     bids_dataset=bids_dataset,
+                     bids_file=bids_file,
+                     eeg_dash_instance=eeg_dash_instance,
+                     s3_bucket=s3_bucket,
+                     description_fields=description_fields,
+                 )
+                 for bids_file in bids_dataset.get_files()
+             )
+             return datasets
+         finally:
+             eeg_dash_instance.close()
+
+     def get_base_dataset_from_bids_file(
+         self,
+         bids_dataset: EEGBIDSDataset,
+         bids_file: str,
+         eeg_dash_instance: EEGDash,
+         s3_bucket: str | None,
+         description_fields: list[str],
+     ) -> EEGDashBaseDataset:
+         """Instantiate a single EEGDashBaseDataset given a local BIDS file. Note that
+         this does not actually load the data from disk, but will access the metadata.
+         """
+         record = eeg_dash_instance.load_eeg_attrs_from_bids_file(
+             bids_dataset, bids_file
+         )
+         description = {}
+         for field in description_fields:
+             value = self.find_key_in_nested_dict(record, field)
+             if value is not None:
+                 description[field] = value
+         return EEGDashBaseDataset(
+             record,
+             self.cache_dir,
+             s3_bucket,
+             description=description,
+         )