eegdash 0.1.0__py3-none-any.whl → 0.2.1.dev178237806__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of eegdash might be problematic.

eegdash/data_config.py CHANGED
@@ -1,5 +1,8 @@
 config = {
     "required_fields": ["data_name"],
+    # Default set of user-facing primary record attributes expected in the database. Records
+    # where any of these are missing will be loaded with the respective attribute set to None.
+    # Additional fields may be returned if they are present in the database, notably bidsdependencies.
     "attributes": {
         "data_name": "str",
         "dataset": "str",
@@ -11,9 +14,12 @@ config = {
         "sampling_frequency": "float",
         "modality": "str",
         "nchans": "int",
-        "ntimes": "int",
+        "ntimes": "int",  # note: this is really the number of seconds in the data, rounded down
     },
+    # queryable descriptive fields for a given recording
     "description_fields": ["subject", "session", "run", "task", "age", "gender", "sex"],
+    # list of filenames that may be present in the BIDS dataset directory and that are used
+    # to load and interpret a given BIDS recording.
     "bids_dependencies_files": [
         "dataset_description.json",
         "participants.tsv",
eegdash/data_utils.py CHANGED
@@ -1,9 +1,9 @@
 import json
+import logging
 import os
 import re
-import sys
-import tempfile
 from pathlib import Path
+from typing import Any
 
 import mne
 import mne_bids
@@ -12,7 +12,7 @@ import pandas as pd
 import s3fs
 from bids import BIDSLayout
 from joblib import Parallel, delayed
-from mne._fiff.utils import _find_channels, _read_segments_file
+from mne._fiff.utils import _read_segments_file
 from mne.io import BaseRaw
 from mne_bids import (
     BIDSPath,
@@ -20,51 +20,62 @@ from mne_bids import (
 
 from braindecode.datasets import BaseDataset
 
+logger = logging.getLogger("eegdash")
 
-class EEGDashBaseDataset(BaseDataset):
-    """Returns samples from an mne.io.Raw object along with a target.
 
-    Dataset which serves samples from an mne.io.Raw object along with a target.
-    The target is unique for the dataset, and is obtained through the
-    `description` attribute.
+class EEGDashBaseDataset(BaseDataset):
+    """A single EEG recording hosted on AWS S3 and cached locally upon first access.
 
-    Parameters
-    ----------
-    raw : mne.io.Raw
-        Continuous data.
-    description : dict | pandas.Series | None
-        Holds additional description about the continuous signal / subject.
-    target_name : str | tuple | None
-        Name(s) of the index in `description` that should be used to provide the
-        target (e.g., to be used in a prediction task later on).
-    transform : callable | None
-        On-the-fly transform applied to the example before it is returned.
+    This is a subclass of braindecode's BaseDataset, so it can be used in
+    conjunction with the preprocessing and training pipelines of braindecode.
     """
 
     AWS_BUCKET = "s3://openneuro.org"
 
-    def __init__(self, record, cache_dir, **kwargs):
+    def __init__(
+        self,
+        record: dict[str, Any],
+        cache_dir: str,
+        s3_bucket: str | None = None,
+        **kwargs,
+    ):
+        """Create a new EEGDashBaseDataset instance. Users do not usually need to call
+        this directly -- instead, use the EEGDashDataset class to load a collection of
+        these recordings from a local BIDS folder or via a database query.
+
+        Parameters
+        ----------
+        record : dict
+            A fully resolved metadata record for the data to load.
+        cache_dir : str
+            A local directory where the data will be cached.
+        kwargs : dict
+            Additional keyword arguments to pass to the BaseDataset constructor.
+
+        """
         super().__init__(None, **kwargs)
         self.record = record
         self.cache_dir = Path(cache_dir)
         bids_kwargs = self.get_raw_bids_args()
+
         self.bidspath = BIDSPath(
             root=self.cache_dir / record["dataset"],
             datatype="eeg",
             suffix="eeg",
             **bids_kwargs,
         )
+        self.s3_bucket = s3_bucket if s3_bucket else self.AWS_BUCKET
         self.s3file = self.get_s3path(record["bidspath"])
         self.filecache = self.cache_dir / record["bidspath"]
         self.bids_dependencies = record["bidsdependencies"]
         self._raw = None
-        # if os.path.exists(self.filecache):
-        #     self.raw = mne_bids.read_raw_bids(self.bidspath, verbose=False)
 
-    def get_s3path(self, filepath):
-        return f"{self.AWS_BUCKET}/{filepath}"
+    def get_s3path(self, filepath: str) -> str:
+        """Helper to form an AWS S3 URI for the given relative filepath."""
+        return f"{self.s3_bucket}/{filepath}"
 
-    def _download_s3(self):
+    def _download_s3(self) -> None:
+        """Fetch the given data from its S3 location and cache it locally."""
         self.filecache.parent.mkdir(parents=True, exist_ok=True)
         filesystem = s3fs.S3FileSystem(
             anon=True, client_kwargs={"region_name": "us-east-2"}
@@ -72,7 +83,10 @@ class EEGDashBaseDataset(BaseDataset):
         filesystem.download(self.s3file, self.filecache)
         self.filenames = [self.filecache]
 
-    def _download_dependencies(self):
+    def _download_dependencies(self) -> None:
+        """Download all BIDS dependency files (metadata files, recording sidecar files)
+        from S3 and cache them locally.
+        """
         filesystem = s3fs.S3FileSystem(
             anon=True, client_kwargs={"region_name": "us-east-2"}
         )
@@ -83,11 +97,15 @@ class EEGDashBaseDataset(BaseDataset):
             filepath.parent.mkdir(parents=True, exist_ok=True)
             filesystem.download(s3path, filepath)
 
-    def get_raw_bids_args(self):
+    def get_raw_bids_args(self) -> dict[str, Any]:
+        """Helper to restrict the metadata record to the fields needed to locate a BIDS
+        recording.
+        """
         desired_fields = ["subject", "session", "task", "run"]
         return {k: self.record[k] for k in desired_fields if self.record[k]}
 
-    def check_and_get_raw(self):
+    def check_and_get_raw(self) -> None:
+        """Download the S3 file and BIDS dependencies if not already cached."""
         if not os.path.exists(self.filecache):  # not preload
             if self.bids_dependencies:
                 self._download_dependencies()
@@ -95,9 +113,10 @@ class EEGDashBaseDataset(BaseDataset):
         if self._raw is None:
             self._raw = mne_bids.read_raw_bids(self.bidspath, verbose=False)
 
-    def __getitem__(self, index):
-        # self.check_and_get_raw()
+    # === BaseDataset and PyTorch Dataset interface ===
 
+    def __getitem__(self, index):
+        """Main function to access a sample from the dataset."""
         X = self.raw[:, index][0]
         y = None
         if self.target_name is not None:
@@ -108,14 +127,20 @@ class EEGDashBaseDataset(BaseDataset):
             X = self.transform(X)
         return X, y
 
-    def __len__(self):
+    def __len__(self) -> int:
+        """Return the number of samples in the dataset."""
         if self._raw is None:
+            # FIXME: this is a bit strange and should definitely not change as a side effect
+            # of accessing the data (which it will, since ntimes is the actual length but rounded down)
             return int(self.record["ntimes"] * self.record["sampling_frequency"])
         else:
             return len(self._raw)
 
     @property
     def raw(self):
+        """Return the MNE Raw object for this recording. This will perform the actual
+        retrieval if it has not been done yet.
+        """
         if self._raw is None:
             self.check_and_get_raw()
         return self._raw
@@ -126,50 +151,44 @@ class EEGDashBaseDataset(BaseDataset):
 
 
 class EEGDashBaseRaw(BaseRaw):
-    r"""MNE Raw object from EEG-Dash connection with Openneuro S3 file.
+    """Wrapper around the MNE BaseRaw class that automatically fetches the data from S3
+    (when _read_segment is called) and caches it locally. Currently for internal use.
 
     Parameters
     ----------
     input_fname : path-like
         Path to the S3 file
-    eog : list | tuple | 'auto'
-        Names or indices of channels that should be designated EOG channels.
-        If 'auto', the channel names containing ``EOG`` or ``EYE`` are used.
-        Defaults to empty tuple.
-    %(preload)s
-        Note that preload=False will be effective only if the data is stored
-        in a separate binary file.
-    %(uint16_codec)s
-    %(montage_units)s
-    %(verbose)s
+    metadata : dict
+        The metadata record for the recording (e.g., from the database).
+    preload : bool
+        Whether to pre-load the data before the first access.
+    cache_dir : str
+        Local path under which the data will be cached.
+    bids_dependencies : list
+        List of additional BIDS metadata files that should be downloaded and cached
+        alongside the main recording file.
+    verbose : str | int | None
+        Optionally the verbosity level for MNE logging (see MNE documentation for possible values).
 
     See Also
     --------
     mne.io.Raw : Documentation of attributes and methods.
 
-    Notes
-    -----
-    .. versionadded:: 0.11.0
     """
 
     AWS_BUCKET = "s3://openneuro.org"
 
     def __init__(
         self,
-        input_fname,
-        metadata,
-        eog=(),
-        preload=False,
+        input_fname: str,
+        metadata: dict[str, Any],
+        preload: bool = False,
         *,
-        cache_dir="./.eegdash_cache",
-        bids_dependencies: list = [],
-        uint16_codec=None,
-        montage_units="auto",
-        verbose=None,
+        cache_dir: str = "./.eegdash_cache",
+        bids_dependencies: list[str] = [],
+        verbose: Any = None,
     ):
-        """
-        Get to work with S3 endpoint first, no caching
-        """
+        """Get to work with S3 endpoint first, no caching"""
         # Create a simple RawArray
         sfreq = metadata["sfreq"]  # Sampling frequency
         n_times = metadata["n_times"]
@@ -237,6 +256,20 @@ class EEGDashBaseRaw(BaseRaw):
 
 
 class EEGBIDSDataset:
+    """A one-stop-shop interface to a local BIDS dataset containing EEG recordings.
+
+    This is mainly tailored to the needs of the EEGDash application and is used to
+    centralize interactions with the BIDS dataset, such as parsing the metadata.
+
+    Parameters
+    ----------
+    data_dir : str | Path
+        The path to the local BIDS dataset directory.
+    dataset : str
+        A name for the dataset.
+
+    """
+
     ALLOWED_FILE_FORMAT = ["eeglab", "brainvision", "biosemi", "european"]
     RAW_EXTENSIONS = {
         ".set": [".set", ".fdt"],  # eeglab
@@ -270,19 +303,13 @@ class EEGBIDSDataset:
                 "Unable to construct EEG dataset. No EEG recordings found."
             )
         assert self.check_eeg_dataset(), ValueError("Dataset is not an EEG dataset.")
-        # temp_dir = (Path().resolve() / 'data')
-        # if not os.path.exists(temp_dir):
-        #     os.mkdir(temp_dir)
-        # if not os.path.exists(temp_dir / f'{dataset}_files.npy'):
-        #     self.files = self.get_files_with_extension_parallel(self.bidsdir, extension=self.RAW_EXTENSION[self.raw_format])
-        #     np.save(temp_dir / f'{dataset}_files.npy', self.files)
-        # else:
-        #     self.files = np.load(temp_dir / f'{dataset}_files.npy', allow_pickle=True)
-
-    def check_eeg_dataset(self):
+
+    def check_eeg_dataset(self) -> bool:
+        """Check whether the dataset is an EEG dataset."""
         return self.get_bids_file_attribute("modality", self.files[0]).lower() == "eeg"
 
-    def get_recordings(self, layout: BIDSLayout):
+    def get_recordings(self, layout: BIDSLayout) -> list[str]:
+        """Get a list of all EEG recording files in the BIDS layout."""
         files = []
         for ext, exts in self.RAW_EXTENSIONS.items():
             files = layout.get(extension=ext, return_type="filename")
@@ -290,11 +317,15 @@ class EEGBIDSDataset:
         if files:
             break
         return files
 
-    def get_relative_bidspath(self, filename):
+    def get_relative_bidspath(self, filename: str) -> str:
+        """Make the given file path relative to the BIDS directory."""
         bids_parent_dir = self.bidsdir.parent.absolute()
         return str(Path(filename).relative_to(bids_parent_dir))
 
-    def get_property_from_filename(self, property, filename):
+    def get_property_from_filename(self, property: str, filename: str) -> str:
+        """Parse a property out of a BIDS-compliant filename. Returns an empty string
+        if not found.
+        """
         import platform
 
         if platform.system() == "Windows":
@@ -303,25 +334,38 @@ class EEGBIDSDataset:
         lookup = re.search(rf"{property}-(.*?)[_\/]", filename)
         return lookup.group(1) if lookup else ""
 
-    def merge_json_inheritance(self, json_files):
-        """
-        Merge list of json files found by get_bids_file_inheritance,
-        expecting the order (from left to right) is from lowest level to highest level,
-        and return a merged dictionary
+    def merge_json_inheritance(self, json_files: list[str | Path]) -> dict:
+        """Internal helper to merge a list of json files found by
+        get_bids_file_inheritance, expecting the order (from left to right)
+        to be from lowest level to highest level, and return a merged dictionary.
         """
         json_files.reverse()
         json_dict = {}
         for f in json_files:
-            json_dict.update(json.load(open(f)))
+            json_dict.update(json.load(open(f)))  # FIXME: should close file
         return json_dict
 
-    def get_bids_file_inheritance(self, path, basename, extension):
-        """
-        Get all files with given extension that applies to the basename file
-        following the BIDS inheritance principle in the order of lowest level first
-        @param
-        basename: bids file basename without _eeg.set extension for example
-        extension: e.g. channels.tsv
+    def get_bids_file_inheritance(
+        self, path: str | Path, basename: str, extension: str
+    ) -> list[Path]:
+        """Get all file paths that apply to the basename file in the specified directory
+        and that end with the specified suffix, recursively searching parent directories
+        (following the BIDS inheritance principle, in the order of lowest level first).
+
+        Parameters
+        ----------
+        path : str | Path
+            The directory path to search for files.
+        basename : str
+            The BIDS file basename, without the trailing suffix (e.g., without _eeg.set).
+        extension : str
+            Only consider files that end with the specified suffix, e.g. channels.tsv.
+
+        Returns
+        -------
+        list[Path]
+            A list of file paths that match the given basename and extension.
+
         """
         top_level_files = ["README", "dataset_description.json", "participants.tsv"]
         bids_files = []
@@ -352,17 +396,25 @@ class EEGBIDSDataset:
             )
         return bids_files
 
-    def get_bids_metadata_files(self, filepath, metadata_file_extension):
-        """
-        (Wrapper for self.get_bids_file_inheritance)
-        Get all BIDS metadata files that are associated with the given filepath, following the BIDS inheritance principle.
-
-        Args:
-            filepath (str or Path): The filepath to get the associated metadata files for.
-            metadata_files_extensions (list): A list of file extensions to search for metadata files.
+    def get_bids_metadata_files(
+        self, filepath: str | Path, metadata_file_extension: list[str]
+    ) -> list[Path]:
+        """Retrieve all metadata file paths that apply to a given data file path and that
+        end with a specific suffix (following the BIDS inheritance principle).
+
+        Parameters
+        ----------
+        filepath : str | Path
+            The filepath to get the associated metadata files for.
+        metadata_file_extension : str
+            Consider only metadata files that end with the specified suffix,
+            e.g., channels.tsv or eeg.json.
+
+        Returns
+        -------
+        list[Path]
+            A list of filepaths for all matching metadata files.
 
-        Returns:
-            list: A list of filepaths for all the associated metadata files
         """
         if isinstance(filepath, str):
             filepath = Path(filepath)
@@ -376,7 +428,11 @@ class EEGBIDSDataset:
         )
         return meta_files
 
-    def scan_directory(self, directory, extension):
+    def scan_directory(self, directory: str, extension: str) -> list[Path]:
+        """Return a list of file paths that end with the given extension in the specified
+        directory. Ignores certain special directories like .git, .datalad, derivatives,
+        and code.
+        """
         result_files = []
         directory_to_ignore = [".git", ".datalad", "derivatives", "code"]
         with os.scandir(directory) as entries:
@@ -391,14 +447,35 @@ class EEGBIDSDataset:
         return result_files
 
     def get_files_with_extension_parallel(
-        self, directory, extension=".set", max_workers=-1
-    ):
+        self, directory: str, extension: str = ".set", max_workers: int = -1
+    ) -> list[Path]:
+        """Efficiently scan a directory and its subdirectories for files that end with
+        the given extension.
+
+        Parameters
+        ----------
+        directory : str
+            The root directory to scan for files.
+        extension : str
+            Only consider files that end with this suffix, e.g. '.set'.
+        max_workers : int
+            Optionally the maximum number of worker threads to use for parallel scanning.
+            Defaults to all available CPU cores if set to -1.
+
+        Returns
+        -------
+        list[Path]
+            A list of filepaths for all matching files.
+
+        """
         result_files = []
         dirs_to_scan = [directory]
 
         # Use joblib.Parallel and delayed to parallelize directory scanning
         while dirs_to_scan:
-            print(f"Scanning {len(dirs_to_scan)} directories...", dirs_to_scan)
+            logger.info(
+                f"Directories to scan: {len(dirs_to_scan)}, directories: {dirs_to_scan}"
+            )
             # Run the scan_directory function in parallel across directories
             results = Parallel(n_jobs=max_workers, prefer="threads", verbose=1)(
                 delayed(self.scan_directory)(d, extension) for d in dirs_to_scan
@@ -412,12 +489,18 @@ class EEGBIDSDataset:
                     dirs_to_scan.append(path)  # Queue up subdirectories to scan
                 else:
                     result_files.append(path)  # Add files to the final result
-            print(f"Current number of files: {len(result_files)}")
+            logger.info(f"Found {len(result_files)} files.")
 
         return result_files
 
-    def load_and_preprocess_raw(self, raw_file, preprocess=False):
-        print(f"Loading {raw_file}")
+    def load_and_preprocess_raw(
+        self, raw_file: str, preprocess: bool = False
+    ) -> np.ndarray:
+        """Utility function to load a raw data file with MNE, apply some simple
+        (hardcoded) preprocessing, and return the result as a numpy array. Not meant
+        for purposes other than testing or debugging.
+        """
+        logger.info(f"Loading raw data from {raw_file}")
         EEG = mne.io.read_raw_eeglab(raw_file, preload=True, verbose="error")
 
         if preprocess:
@@ -429,9 +512,6 @@ class EEGBIDSDataset:
             sfreq = 128
             if EEG.info["sfreq"] != sfreq:
                 EEG = EEG.resample(sfreq)
-            # # normalize data to zero mean and unit variance
-            # scalar = preprocessing.StandardScaler()
-            # mat_data = scalar.fit_transform(mat_data.T).T  # scalar normalize for each feature and expects shape data x features
 
         mat_data = EEG.get_data()
 
@@ -439,17 +519,22 @@ class EEGBIDSDataset:
             raise ValueError("Expect raw data to be CxT dimension")
         return mat_data
 
-    def get_files(self):
+    def get_files(self) -> list[Path]:
+        """Get all EEG recording file paths (with valid extensions) in the BIDS folder."""
         return self.files
 
-    def resolve_bids_json(self, json_files: list):
-        """
-        Resolve the BIDS JSON files and return a dictionary of the resolved values.
-        Args:
-            json_files (list): A list of JSON files to resolve in order of leaf level first
+    def resolve_bids_json(self, json_files: list[str]) -> dict:
+        """Resolve the BIDS JSON files and return a dictionary of the resolved values.
 
-        Returns:
+        Parameters
+        ----------
+        json_files : list
+            A list of JSON file paths to resolve, in order of leaf level first.
+
+        Returns
+        -------
             dict: A dictionary of the resolved values.
+
         """
         if len(json_files) == 0:
             raise ValueError("No JSON files provided")
@@ -461,7 +546,10 @@ class EEGBIDSDataset:
                 json_dict.update(json.load(f))
         return json_dict
 
-    def get_bids_file_attribute(self, attribute, data_filepath):
+    def get_bids_file_attribute(self, attribute: str, data_filepath: str) -> Any:
+        """Retrieve a specific attribute from the BIDS file metadata applicable
+        to the provided recording file path.
+        """
         entities = self.layout.parse_file_entities(data_filepath)
         bidsfile = self.layout.get(**entities)[0]
         attributes = bidsfile.get_entities(metadata="all")
@@ -478,27 +566,32 @@ class EEGBIDSDataset:
         attribute_value = attributes.get(attribute_mapping.get(attribute), None)
         return attribute_value
 
-    def channel_labels(self, data_filepath):
+    def channel_labels(self, data_filepath: str) -> list[str]:
+        """Get a list of channel labels for the given data file path."""
         channels_tsv = pd.read_csv(
             self.get_bids_metadata_files(data_filepath, "channels.tsv")[0], sep="\t"
         )
         return channels_tsv["name"].tolist()
 
-    def channel_types(self, data_filepath):
+    def channel_types(self, data_filepath: str) -> list[str]:
+        """Get a list of channel types for the given data file path."""
         channels_tsv = pd.read_csv(
             self.get_bids_metadata_files(data_filepath, "channels.tsv")[0], sep="\t"
         )
         return channels_tsv["type"].tolist()
 
-    def num_times(self, data_filepath):
+    def num_times(self, data_filepath: str) -> int:
+        """Get the approximate number of time points in the EEG recording based on the BIDS metadata."""
         eeg_jsons = self.get_bids_metadata_files(data_filepath, "eeg.json")
         eeg_json_dict = self.merge_json_inheritance(eeg_jsons)
         return int(
             eeg_json_dict["SamplingFrequency"] * eeg_json_dict["RecordingDuration"]
         )
 
-    def subject_participant_tsv(self, data_filepath):
-        """Get participants_tsv info of a subject based on filepath"""
+    def subject_participant_tsv(self, data_filepath: str) -> dict[str, Any]:
+        """Get the BIDS participants.tsv record for the subject to which the given
+        file path corresponds, as a dictionary.
+        """
         participants_tsv = pd.read_csv(
             self.get_bids_metadata_files(data_filepath, "participants.tsv")[0], sep="\t"
         )
@@ -510,12 +603,16 @@ class EEGBIDSDataset:
         subject = f"sub-{self.get_bids_file_attribute('subject', data_filepath)}"
         return participants_tsv.loc[subject].to_dict()
 
-    def eeg_json(self, data_filepath):
+    def eeg_json(self, data_filepath: str) -> dict[str, Any]:
+        """Get BIDS eeg.json metadata for the given data file path."""
         eeg_jsons = self.get_bids_metadata_files(data_filepath, "eeg.json")
         eeg_json_dict = self.merge_json_inheritance(eeg_jsons)
         return eeg_json_dict
 
-    def channel_tsv(self, data_filepath):
+    def channel_tsv(self, data_filepath: str) -> dict[str, Any]:
+        """Get BIDS channels.tsv metadata for the given data file path, as a dictionary
+        of lists and/or single values.
+        """
         channels_tsv = pd.read_csv(
             self.get_bids_metadata_files(data_filepath, "channels.tsv")[0], sep="\t"
         )
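
To make the lazy-download behavior above concrete, here is a minimal sketch of how an EEGDashBaseDataset defers all S3 access until the raw data is first touched (the record contents are illustrative; in practice, records come from the metadata database or a local BIDS scan):

    from eegdash.data_utils import EEGDashBaseDataset

    record = {  # illustrative metadata record
        "dataset": "ds005505",
        "bidspath": "ds005505/sub-01/eeg/sub-01_task-RestingState_eeg.set",
        "bidsdependencies": [],
        "subject": "01", "session": None, "task": "RestingState", "run": None,
        "ntimes": 100, "sampling_frequency": 128.0,
    }
    ds = EEGDashBaseDataset(record, cache_dir=".eegdash_cache")
    print(len(ds))  # metadata only: int(ntimes * sampling_frequency), no download
    X, y = ds[0]    # first access downloads the recording from S3 and caches it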
eegdash/dataset.py ADDED
@@ -0,0 +1,69 @@
+from .api import EEGDashDataset
+
+
+class EEGChallengeDataset(EEGDashDataset):
+    def __init__(
+        self,
+        release: str = "R5",
+        query: dict | None = None,
+        cache_dir: str = ".eegdash_cache",
+        s3_bucket: str | None = "s3://nmdatasets/NeurIPS25/",
+        **kwargs,
+    ):
+        """Create a new EEGChallengeDataset from a given query or local BIDS dataset
+        directory and dataset name. An EEGDashDataset is a pooled collection of
+        EEGDashBaseDataset instances (individual recordings) and is a subclass of
+        braindecode's BaseConcatDataset.
+
+        Parameters
+        ----------
+        release : str
+            Release name. Can be one of ["R1", ..., "R11"].
+        query : dict | None
+            Optionally a dictionary that specifies a query to be executed,
+            in addition to the dataset (automatically inferred from the release
+            argument). See EEGDash.find() for details on the query format.
+        cache_dir : str
+            A directory where the dataset will be cached locally.
+        s3_bucket : str | None
+            An optional S3 bucket URI to use instead of the
+            default OpenNeuro bucket for loading data files.
+        kwargs : dict
+            Additional keyword arguments to be passed to the EEGDashDataset
+            constructor.
+
+        """
+        dsnumber_release_map = {
+            "R11": "ds005516",
+            "R10": "ds005515",
+            "R9": "ds005514",
+            "R8": "ds005512",
+            "R7": "ds005511",
+            "R6": "ds005510",
+            "R4": "ds005508",
+            "R5": "ds005509",
+            "R3": "ds005507",
+            "R2": "ds005506",
+            "R1": "ds005505",
+        }
+
+        self.release = release
+        if release not in dsnumber_release_map:
+            raise ValueError(f"Unknown release: {release}")
+
+        dataset = dsnumber_release_map[release]
+        if query is None:
+            query = {"dataset": dataset}
+        elif "dataset" not in query:
+            query["dataset"] = dataset
+        elif query["dataset"] != dataset:
+            raise ValueError(
+                f"Query dataset {query['dataset']} does not match the release {release}, "
+                f"which corresponds to dataset {dataset}."
+            )
+
+        super().__init__(
+            query=query,
+            cache_dir=cache_dir,
+            s3_bucket=f"{s3_bucket}/{release}_L100",
+            **kwargs,
+        )
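
In use, the new class resolves a release name to its OpenNeuro dataset ID and scopes the query to it; a minimal sketch (the cache directory value is illustrative):

    from eegdash.dataset import EEGChallengeDataset

    # Loads release R5, which maps to dataset ds005509, from the challenge bucket.
    ds = EEGChallengeDataset(release="R5", cache_dir=".eegdash_cache")

    # Passing a query whose "dataset" conflicts with the release raises ValueError:
    # EEGChallengeDataset(release="R5", query={"dataset": "ds005505"})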