eegdash 0.1.0__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of eegdash might be problematic. Click here for more details.

eegdash/__init__.py CHANGED
@@ -1,4 +1,8 @@
1
- from .main import EEGDash, EEGDashDataset
1
+ from .api import EEGDash, EEGDashDataset
2
+ from .dataset import EEGChallengeDataset
3
+ from .utils import __init__mongo_client
2
4
 
3
- __all__ = ["EEGDash", "EEGDashDataset"]
4
- __version__ = "0.1.0"
5
# Public names re-exported by the eegdash package.
__all__ = ["EEGDash", "EEGDashDataset", "EEGChallengeDataset"]
__version__ = "0.2.0"

# Runs once at import time. NOTE(review): presumably this prepares the settings
# that EEGDash(is_public=True) reads via mne.utils.get_config("EEGDASH_DB_URI");
# confirm against eegdash.utils.
__init__mongo_client()
eegdash/api.py ADDED
@@ -0,0 +1,690 @@
1
+ import logging
2
+ import os
3
+ import tempfile
4
+ from pathlib import Path
5
+ from typing import Any, Mapping
6
+
7
+ import mne
8
+ import numpy as np
9
+ import xarray as xr
10
+ from dotenv import load_dotenv
11
+ from joblib import Parallel, delayed
12
+ from pymongo import InsertOne, MongoClient, UpdateOne
13
+ from s3fs import S3FileSystem
14
+
15
+ from braindecode.datasets import BaseConcatDataset
16
+
17
+ from .data_config import config as data_config
18
+ from .data_utils import EEGBIDSDataset, EEGDashBaseDataset
19
+
20
# Shared package logger; users can tune verbosity via logging.getLogger("eegdash").
logger = logging.getLogger(name="eegdash")
21
+
22
+
23
class EEGDash:
    """A high-level interface to the EEGDash database.

    This class is primarily used to interact with the metadata records stored in the
    EEGDash database (or a private instance of it), allowing users to find, add, and
    update EEG data records.

    While this class provides basic support for loading EEG data, please see
    the EEGDashDataset class for a more complete way to retrieve and work with full
    datasets.

    """

    AWS_BUCKET = "s3://openneuro.org"

    def __init__(self, *, is_public: bool = True, is_staging: bool = False) -> None:
        """Create new instance of the EEGDash Database client.

        Parameters
        ----------
        is_public: bool
            Whether to connect to the public MongoDB database; if False, connect to a
            private database instance as per the DB_CONNECTION_STRING env variable
            (or .env file entry).
        is_staging: bool
            If True, use staging MongoDB database ("eegdashstaging"); otherwise use the
            production database ("eegdash").

        Example
        -------
        >>> eegdash = EEGDash()

        """
        self.config = data_config
        self.is_public = is_public

        if self.is_public:
            db_connection_string = mne.utils.get_config("EEGDASH_DB_URI")
        else:
            load_dotenv()
            db_connection_string = os.getenv("DB_CONNECTION_STRING")
        if db_connection_string is None:
            # MongoClient(None) falls back to its default host (localhost), which
            # is rarely intended; make the misconfiguration visible.
            logger.warning(
                "No MongoDB connection string configured; MongoClient will use "
                "its default host."
            )

        self.__client = MongoClient(db_connection_string)
        self.__db = self.__client["eegdashstaging" if is_staging else "eegdash"]
        self.__collection = self.__db["records"]

        self.filesystem = S3FileSystem(
            anon=True, client_kwargs={"region_name": "us-east-2"}
        )

    # MongoDB Operations
    # These methods provide a high-level interface to interact with the MongoDB
    # collection, allowing users to find, add, and update EEG data records.
    # - find: return all records matching a query.
    # - exist: check whether at least one record matches a (restricted) query.
    # - add_request / update_request: build bulk-write operations.
    # - add / update: insert or update a single record (errors are logged).
    # - remove_field: unset a field on one record.
    # - remove_field_from_db: unset a field on every record (destructive).
    # - close: Close the MongoDB connection.
    # - __del__: Destructor to close the MongoDB connection.

    def find(self, query: dict[str, Any], *args, **kwargs) -> list[Mapping[str, Any]]:
        """Find records in the MongoDB collection that satisfy the given query.

        Parameters
        ----------
        query: dict
            A dictionary that specifies the query to be executed; this is a reference
            document that is used to match records in the MongoDB collection.
        args:
            Additional positional arguments for the MongoDB find() method; see
            https://pymongo.readthedocs.io/en/stable/api/pymongo/collection.html#pymongo.collection.Collection.find
        kwargs:
            Additional keyword arguments for the MongoDB find() method.

        Returns
        -------
        list:
            A list of DB records (string-keyed dictionaries) that match the query.

        Example
        -------
        >>> eegdash = EEGDash()
        >>> eegdash.find({"dataset": "ds002718", "subject": "012"})

        """
        return list(self.__collection.find(query, *args, **kwargs))

    def exist(self, query: dict[str, Any]) -> bool:
        """Check if the given query matches any records in the MongoDB collection.

        Note that currently only a limited set of query fields is allowed here.

        Parameters
        ----------
        query: dict
            A dictionary that specifies the query to be executed; this is a reference
            document that is used to match records in the MongoDB collection.
            Only the keys "data_name" and "dataset" are accepted.

        Returns
        -------
        bool:
            True if at least one record matches the query, False otherwise.

        Raises
        ------
        ValueError
            If the query uses a field other than "data_name" or "dataset".

        """
        accepted_query_fields = {"data_name", "dataset"}
        invalid_fields = set(query) - accepted_query_fields
        # An explicit raise (unlike `assert`) survives `python -O`.
        if invalid_fields:
            raise ValueError(
                f"Unsupported query fields for exist(): {sorted(invalid_fields)}"
            )
        # Fetch at most one document (and only its _id) instead of materialising
        # every match, e.g. all records of a whole dataset.
        return self.__collection.find_one(query, projection={"_id": 1}) is not None

    def _validate_input(self, record: dict[str, Any]) -> dict[str, Any]:
        """Internal method to validate the input record against the expected schema.

        Parameters
        ----------
        record: dict
            A dictionary representing the EEG data record to be validated.

        Returns
        -------
        dict:
            Returns the record itself on success, or raises a ValueError if the record is invalid.

        """
        input_types = {
            "data_name": str,
            "dataset": str,
            "bidspath": str,
            "subject": str,
            "task": str,
            "session": str,
            "run": str,
            "sampling_frequency": float,
            "modality": str,
            "nchans": int,
            "ntimes": int,
            "channel_types": list,
            "channel_names": list,
        }
        if "data_name" not in record:
            raise ValueError("Missing key: data_name")
        # check if args are in the keys and has correct type
        for key, value in record.items():
            if key not in input_types:
                raise ValueError(f"Invalid input: {key}")
            if not isinstance(value, input_types[key]):
                raise ValueError(f"Invalid input: {key}")

        return record

    def get_s3path(self, record: Mapping[str, Any]) -> str:
        """Return the S3 URI of a record's main data file.

        The record's "bidspath" (e.g. "ds002718/sub-012/eeg/...set") is interpreted
        relative to AWS_BUCKET.
        """
        return f"{self.AWS_BUCKET}/{record['bidspath']}"

    def load_eeg_data_from_s3(self, s3path: str) -> xr.DataArray:
        """Load an EEG recording from an AWS S3 URI and return it as an xarray DataArray.

        The file is downloaded into a temporary directory under its original file
        name, so MNE can infer the format from the real extension. For EEGLAB .set
        files, the companion .fdt data file is downloaded too if it exists on S3.
        The temporary directory is removed afterwards, even on failure.

        Parameters
        ----------
        s3path : str
            An S3 URI (should start with "s3://") for the file in question.

        Returns
        -------
        xr.DataArray
            A DataArray containing the EEG data, with dimensions "channel" and "time".

        Example
        -------
        >>> eegdash = EEGDash()
        >>> mypath = "s3://openneuro.org/path/to/your/eeg_data.set"
        >>> mydata = eegdash.load_eeg_data_from_s3(mypath)

        """
        remote_paths = [s3path]
        if s3path.endswith(".set"):
            fdt_path = s3path[: -len(".set")] + ".fdt"
            if self.filesystem.exists(fdt_path):
                remote_paths.append(fdt_path)

        with tempfile.TemporaryDirectory() as tmp_dir:
            local_paths = []
            for remote_path in remote_paths:
                local_path = os.path.join(tmp_dir, remote_path.rsplit("/", 1)[-1])
                with self.filesystem.open(remote_path) as s3_file, open(
                    local_path, "wb"
                ) as local_file:
                    local_file.write(s3_file.read())
                local_paths.append(local_path)
            # The main file is always first; sidecars only need to sit next to it.
            return self.load_eeg_data_from_bids_file(local_paths[0])

    def load_eeg_data_from_bids_file(self, bids_file: str) -> xr.DataArray:
        """Load EEG data from a local file and return it as a xarray DataArray.

        Parameters
        ----------
        bids_file : str
            Path to the file on the local filesystem.

        Returns
        -------
        xr.DataArray
            Data with dims ("channel", "time"); the "time" coordinate is in seconds,
            with sample k at k / sfreq.

        Notes
        -----
        Currently, only non-epoched .set files are supported.

        """
        raw_object = mne.io.read_raw(bids_file)
        eeg_data = raw_object.get_data()

        fs = raw_object.info["sfreq"]
        # Sample k is taken at k / fs seconds. np.linspace(0, n / fs, n) would
        # space samples n / ((n - 1) * fs) apart, stretching the time axis.
        time_steps = np.arange(eeg_data.shape[1]) / fs

        channel_names = raw_object.ch_names

        eeg_xarray = xr.DataArray(
            data=eeg_data,
            dims=["channel", "time"],
            coords={"time": time_steps, "channel": channel_names},
        )
        return eeg_xarray

    def get_raw_extensions(
        self, bids_file: str, bids_dataset: EEGBIDSDataset
    ) -> list[str]:
        """Helper to find paths to additional "sidecar" files that may be associated
        with a given main data file in a BIDS dataset; paths are returned as relative to
        the parent dataset path.

        For example, if the input file is a .set file, this will return the relative path
        to a corresponding .fdt file (if any). Formats not listed below are treated as
        single-file recordings, so only the file itself is returned (if it exists).
        """
        bids_file = Path(bids_file)
        extensions = {
            ".set": [".set", ".fdt"],  # eeglab
            ".edf": [".edf"],  # european
            ".vhdr": [".eeg", ".vhdr", ".vmrk", ".dat", ".raw"],  # brainvision
            ".bdf": [".bdf"],  # biosemi
        }
        relative_paths = []
        for suffix in extensions.get(bids_file.suffix, [bids_file.suffix]):
            candidate = bids_file.with_suffix(suffix)
            if candidate.exists():
                relative_paths.append(
                    str(bids_dataset.get_relative_bidspath(candidate))
                )
        return relative_paths

    def load_eeg_attrs_from_bids_file(
        self, bids_dataset: EEGBIDSDataset, bids_file: str
    ) -> dict[str, Any]:
        """Build the metadata record for a given BIDS file (single recording) in a BIDS dataset.

        Attributes are at least the ones defined in data_config attributes (set to None if missing),
        but are typically a superset, and include, among others, the paths to relevant
        meta-data files needed to load and interpret the file in question.

        Parameters
        ----------
        bids_dataset : EEGBIDSDataset
            The BIDS dataset object containing the file.
        bids_file : str
            The path to the BIDS file within the dataset.

        Returns
        -------
        dict:
            A dictionary representing the metadata record for the given file. This is the
            same format as the records stored in the database.

        Raises
        ------
        ValueError
            If the file is not part of the dataset, or its path does not contain the
            dataset name.

        """
        if bids_file not in bids_dataset.files:
            raise ValueError(f"{bids_file} not in {bids_dataset.dataset}")

        # Initialize attrs with None values for all expected fields
        attrs = {field: None for field in self.config["attributes"].keys()}

        file = Path(bids_file).name
        dsnumber = bids_dataset.dataset
        # The OpenNeuro path starts at the FIRST occurrence of the dataset name;
        # everything after it is kept (split()[1] would truncate the path when the
        # name occurs more than once, e.g. /data/ds1/ds1/sub-01/...).
        start = bids_file.find(dsnumber)
        if start < 0:
            raise ValueError(f"{dsnumber} does not occur in the path {bids_file}")
        openneuro_path = bids_file[start:]

        # Update with actual values where available
        try:
            participants_tsv = bids_dataset.subject_participant_tsv(bids_file)
        except Exception as e:
            logger.error("Error getting participants_tsv: %s", str(e))
            participants_tsv = None

        try:
            eeg_json = bids_dataset.eeg_json(bids_file)
        except Exception as e:
            logger.error("Error getting eeg_json: %s", str(e))
            eeg_json = None

        bids_dependencies_files = self.config["bids_dependencies_files"]
        bidsdependencies = []
        for extension in bids_dependencies_files:
            try:
                dep_path = bids_dataset.get_bids_metadata_files(bids_file, extension)
                dep_path = [
                    str(bids_dataset.get_relative_bidspath(dep)) for dep in dep_path
                ]
                bidsdependencies.extend(dep_path)
            except Exception as e:
                # Missing dependency files are expected for some datasets; keep
                # going, but leave a trace instead of swallowing silently.
                logger.debug(
                    "Could not resolve %s dependency for %s: %s",
                    extension,
                    bids_file,
                    str(e),
                )

        bidsdependencies.extend(self.get_raw_extensions(bids_file, bids_dataset))

        # Define field extraction functions with error handling
        field_extractors = {
            "data_name": lambda: f"{bids_dataset.dataset}_{file}",
            "dataset": lambda: bids_dataset.dataset,
            "bidspath": lambda: openneuro_path,
            "subject": lambda: bids_dataset.get_bids_file_attribute(
                "subject", bids_file
            ),
            "task": lambda: bids_dataset.get_bids_file_attribute("task", bids_file),
            "session": lambda: bids_dataset.get_bids_file_attribute(
                "session", bids_file
            ),
            "run": lambda: bids_dataset.get_bids_file_attribute("run", bids_file),
            "modality": lambda: bids_dataset.get_bids_file_attribute(
                "modality", bids_file
            ),
            "sampling_frequency": lambda: bids_dataset.get_bids_file_attribute(
                "sfreq", bids_file
            ),
            "nchans": lambda: bids_dataset.get_bids_file_attribute("nchans", bids_file),
            "ntimes": lambda: bids_dataset.get_bids_file_attribute("ntimes", bids_file),
            "participant_tsv": lambda: participants_tsv,
            "eeg_json": lambda: eeg_json,
            "bidsdependencies": lambda: bidsdependencies,
        }

        # Dynamically populate attrs with error handling
        for field, extractor in field_extractors.items():
            try:
                attrs[field] = extractor()
            except Exception as e:
                logger.error("Error extracting %s : %s", field, str(e))
                attrs[field] = None

        return attrs

    def add_bids_dataset(
        self, dataset: str, data_dir: str, overwrite: bool = True
    ) -> None:
        """Traverse the BIDS dataset at data_dir and add its records to the MongoDB database,
        under the given dataset name.

        Parameters
        ----------
        dataset : str
            The name of the dataset to be added (e.g., "ds002718").
        data_dir : str
            The path to the BIDS dataset directory.
        overwrite : bool
            Whether to overwrite/update existing records in the database.

        Raises
        ------
        ValueError
            If this client is connected to the public database.

        """
        if self.is_public:
            raise ValueError("This operation is not allowed for public users")

        if not overwrite and self.exist({"dataset": dataset}):
            logger.info("Dataset %s already exists in the database", dataset)
            return
        try:
            bids_dataset = EEGBIDSDataset(
                data_dir=data_dir,
                dataset=dataset,
            )
        except Exception as e:
            logger.error("Error creating bids dataset %s: %s", dataset, str(e))
            raise
        requests = []
        for bids_file in bids_dataset.get_files():
            try:
                data_id = f"{dataset}_{Path(bids_file).name}"
                already_stored = self.exist({"data_name": data_id})
                if already_stored and not overwrite:
                    continue
                eeg_attrs = self.load_eeg_attrs_from_bids_file(bids_dataset, bids_file)
                # Existing records are updated in place; new ones are inserted.
                if already_stored:
                    requests.append(self.update_request(eeg_attrs))
                else:
                    requests.append(self.add_request(eeg_attrs))
            except Exception as e:
                logger.error("Error adding record %s", bids_file)
                logger.error(str(e))

        logger.info("Number of requests: %s", len(requests))

        if requests:
            result = self.__collection.bulk_write(requests, ordered=False)
            logger.info("Inserted: %s ", result.inserted_count)
            logger.info("Modified: %s ", result.modified_count)
            logger.info("Deleted: %s", result.deleted_count)
            logger.info("Upserted: %s", result.upserted_count)
            logger.info("Errors: %s ", result.bulk_api_result.get("writeErrors", []))

    def get(self, query: dict[str, Any]) -> list[xr.DataArray]:
        """Retrieve a list of EEG data arrays that match the given query. See also
        the `find()` method for details on the query format.

        Parameters
        ----------
        query : dict
            A dictionary that specifies the query to be executed; this is a reference
            document that is used to match records in the MongoDB collection.

        Returns
        -------
        A list of xarray DataArray objects containing the EEG data for each matching record.

        Notes
        -----
        Retrieval is done in parallel, and the downloaded data are not cached locally.

        """
        sessions = self.find(query)
        results = []
        if sessions:
            logger.info("Found %s records", len(sessions))
            results = Parallel(
                n_jobs=-1 if len(sessions) > 1 else 1, prefer="threads", verbose=1
            )(
                delayed(self.load_eeg_data_from_s3)(self.get_s3path(session))
                for session in sessions
            )
        return results

    def add_request(self, record: dict):
        """Internal helper method to create a MongoDB insertion request for a record."""
        return InsertOne(record)

    def add(self, record: dict):
        """Add a single record to the MongoDB collection.

        This is best-effort: failures are logged rather than raised.
        """
        try:
            self.__collection.insert_one(record)
        except ValueError as e:
            logger.error("Validation error for record: %s ", record.get("data_name"))
            logger.error(e)
        except Exception:
            # Not a bare `except:`, so KeyboardInterrupt/SystemExit still propagate.
            logger.exception("Error adding record: %s ", record.get("data_name"))

    def update_request(self, record: dict):
        """Internal helper method to create a MongoDB update request for a record."""
        return UpdateOne({"data_name": record["data_name"]}, {"$set": record})

    def update(self, record: dict):
        """Update a single record in the MongoDB collection.

        This is best-effort: failures are logged rather than raised.
        """
        try:
            self.__collection.update_one(
                {"data_name": record["data_name"]}, {"$set": record}
            )
        except Exception:  # logged, not raised
            logger.exception("Error updating record: %s", record.get("data_name"))

    def remove_field(self, record, field):
        """Remove a specific field from a record in the MongoDB collection."""
        self.__collection.update_one(
            {"data_name": record["data_name"]}, {"$unset": {field: 1}}
        )

    def remove_field_from_db(self, field):
        """Removed all occurrences of a specific field from all records in the MongoDB
        collection. WARNING: this operation is destructive and should be used with caution.
        """
        self.__collection.update_many({}, {"$unset": {field: 1}})

    @property
    def collection(self):
        """Return the MongoDB collection object."""
        return self.__collection

    def close(self):
        """Close the MongoDB client connection."""
        # The attribute may be missing if __init__ failed before creating it.
        if hasattr(self, "_EEGDash__client"):
            self.__client.close()

    def __del__(self):
        """Ensure connection is closed when object is deleted."""
        self.close()
501
+
502
+
503
class EEGDashDataset(BaseConcatDataset):
    """A pooled collection of EEGDashBaseDataset recordings (braindecode BaseConcatDataset)."""

    # Default metadata fields copied into each recording's description. This is a
    # tuple so that the default cannot be mutated and shared across instances.
    _DEFAULT_DESCRIPTION_FIELDS = (
        "subject",
        "session",
        "run",
        "task",
        "age",
        "gender",
        "sex",
    )

    def __init__(
        self,
        query: dict | None = None,
        data_dir: str | list | None = None,
        dataset: str | list | None = None,
        description_fields: list[str] | None = None,
        cache_dir: str = ".eegdash_cache",
        s3_bucket: str | None = None,
        **kwargs,
    ):
        """Create a new EEGDashDataset from a given query or local BIDS dataset directory
        and dataset name. An EEGDashDataset is pooled collection of EEGDashBaseDataset
        instances (individual recordings) and is a subclass of braindecode's BaseConcatDataset.

        Parameters
        ----------
        query : dict | None
            Optionally a dictionary that specifies the query to be executed; see
            EEGDash.find() for details on the query format.
        data_dir : str | os.PathLike | list[str] | None
            Optionally a string or a list of strings specifying one or more local
            BIDS dataset directories from which to load the EEG data files. Exactly one
            of query or data_dir must be provided (if both are given, query is used).
        dataset : str | list[str] | None
            If data_dir is given, a name or list of names for the dataset(s) to be loaded.
            When data_dir is a list, dataset must be a list of the same length.
        description_fields : list[str] | None
            A list of fields to be extracted from the dataset records
            and included in the returned data description(s). Examples are typical
            subject metadata fields such as "subject", "session", "run", "task", etc.
            Defaults to ["subject", "session", "run", "task", "age", "gender", "sex"].
        cache_dir : str
            A directory where the dataset will be cached locally.
        s3_bucket : str | None
            An optional S3 bucket URI (e.g., "s3://mybucket") to use instead of the
            default OpenNeuro bucket for loading data files
        kwargs : dict
            Additional keyword arguments to be passed to the EEGDashBaseDataset
            constructor (for both the query and the data_dir code paths).

        Raises
        ------
        ValueError
            If neither query nor data_dir is given, or if the numbers of data
            directories and dataset names do not match.

        """
        if description_fields is None:
            description_fields = list(self._DEFAULT_DESCRIPTION_FIELDS)
        self.cache_dir = cache_dir
        self.s3_bucket = s3_bucket
        if query:
            datasets = self.find_datasets(query, description_fields, **kwargs)
        elif data_dir:
            if isinstance(data_dir, (str, os.PathLike)):
                datasets = self.load_bids_dataset(
                    dataset, os.fspath(data_dir), description_fields, s3_bucket, **kwargs
                )
            else:
                # A plain string `dataset` must not be compared by character count.
                if (
                    dataset is None
                    or isinstance(dataset, str)
                    or len(data_dir) != len(dataset)
                ):
                    raise ValueError(
                        "Number of datasets and their directories must match"
                    )
                datasets = []
                for ds_name, ds_dir in zip(dataset, data_dir):
                    datasets.extend(
                        self.load_bids_dataset(
                            ds_name,
                            os.fspath(ds_dir),
                            description_fields,
                            s3_bucket,
                            **kwargs,
                        )
                    )
        else:
            raise ValueError("Exactly one of query or data_dir must be provided")

        super().__init__(datasets)

    def find_key_in_nested_dict(self, data: Any, target_key: str) -> Any:
        """Helper to recursively search for a key in a nested dictionary structure; returns
        the value associated with the first occurrence of the key, or None if not found.
        """
        if isinstance(data, dict):
            if target_key in data:
                return data[target_key]
            for value in data.values():
                result = self.find_key_in_nested_dict(value, target_key)
                if result is not None:
                    return result
        return None

    def _extract_description(
        self, record: dict[str, Any], description_fields: list[str]
    ) -> dict[str, Any]:
        """Collect the requested fields that are present (non-None) anywhere in a record."""
        description = {}
        for field in description_fields:
            value = self.find_key_in_nested_dict(record, field)
            if value is not None:
                description[field] = value
        return description

    def find_datasets(
        self, query: dict[str, Any], description_fields: list[str], **kwargs
    ) -> list[EEGDashBaseDataset]:
        """Helper method to find datasets in the MongoDB collection that satisfy the
        given query and return them as a list of EEGDashBaseDataset objects.

        Parameters
        ----------
        query : dict
            The query object, as in EEGDash.find().
        description_fields : list[str]
            A list of fields to be extracted from the dataset records and included in
            the returned dataset description(s).
        kwargs: additional keyword arguments to be passed to the EEGDashBaseDataset
            constructor.

        Returns
        -------
        list :
            A list of EEGDashBaseDataset objects that match the query.

        """
        eeg_dash_instance = EEGDash()
        try:
            return [
                EEGDashBaseDataset(
                    record,
                    self.cache_dir,
                    self.s3_bucket,
                    description=self._extract_description(record, description_fields),
                    **kwargs,
                )
                for record in eeg_dash_instance.find(query)
            ]
        finally:
            eeg_dash_instance.close()

    def load_bids_dataset(
        self,
        dataset,
        data_dir,
        description_fields: list[str],
        s3_bucket: str | None = None,
        **kwargs,
    ):
        """Helper method to load a single local BIDS dataset and return it as a list of
        EEGDashBaseDatasets (one for each recording in the dataset).

        Parameters
        ----------
        dataset : str
            A name for the dataset to be loaded (e.g., "ds002718").
        data_dir : str
            The path to the local BIDS dataset directory.
        description_fields : list[str]
            A list of fields to be extracted from the dataset records
            and included in the returned dataset description(s).
        s3_bucket : str | None
            Optional S3 bucket URI forwarded to each EEGDashBaseDataset.
        kwargs:
            Additional keyword arguments for the EEGDashBaseDataset constructor.

        """

        def get_base_dataset_from_bids_file(
            bids_dataset: EEGBIDSDataset,
            bids_file: str,
            eeg_dash_instance: EEGDash,
            s3_bucket: str | None,
        ) -> EEGDashBaseDataset:
            """Instantiate a single EEGDashBaseDataset given a local BIDS file. Note
            this does not actually load the data from disk, but will access the metadata.
            """
            record = eeg_dash_instance.load_eeg_attrs_from_bids_file(
                bids_dataset, bids_file
            )
            description = self._extract_description(record, description_fields)
            return EEGDashBaseDataset(
                record, self.cache_dir, s3_bucket, description=description, **kwargs
            )

        bids_dataset = EEGBIDSDataset(
            data_dir=data_dir,
            dataset=dataset,
        )
        eeg_dash_instance = EEGDash()
        try:
            datasets = Parallel(n_jobs=-1, prefer="threads", verbose=1)(
                delayed(get_base_dataset_from_bids_file)(
                    bids_dataset, bids_file, eeg_dash_instance, s3_bucket
                )
                for bids_file in bids_dataset.get_files()
            )
            return datasets
        finally:
            eeg_dash_instance.close()