eegdash 0.4.0.dev173498563__py3-none-any.whl → 0.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of eegdash might be problematic. Click here for more details.

eegdash/data_utils.py CHANGED
@@ -20,13 +20,10 @@ from typing import Any
20
20
 
21
21
  import mne
22
22
  import mne_bids
23
- import numpy as np
24
23
  import pandas as pd
25
- from bids import BIDSLayout
26
- from joblib import Parallel, delayed
27
24
  from mne._fiff.utils import _read_segments_file
28
25
  from mne.io import BaseRaw
29
- from mne_bids import BIDSPath
26
+ from mne_bids import BIDSPath, find_matching_paths
30
27
 
31
28
  from braindecode.datasets import BaseDataset
32
29
 
@@ -37,10 +34,26 @@ from .paths import get_default_cache_dir
37
34
 
38
35
 
39
36
  class EEGDashBaseDataset(BaseDataset):
40
- """A single EEG recording hosted on AWS S3 and cached locally upon first access.
37
+ """A single EEG recording dataset.
38
+
39
+ Represents a single EEG recording, typically hosted on a remote server (like AWS S3)
40
+ and cached locally upon first access. This class is a subclass of
41
+ :class:`braindecode.datasets.BaseDataset` and can be used with braindecode's
42
+ preprocessing and training pipelines.
43
+
44
+ Parameters
45
+ ----------
46
+ record : dict
47
+ A fully resolved metadata record for the data to load.
48
+ cache_dir : str
49
+ The local directory where the data will be cached.
50
+ s3_bucket : str, optional
51
+ The S3 bucket to download data from. If not provided, defaults to the
52
+ OpenNeuro bucket.
53
+ **kwargs
54
+ Additional keyword arguments passed to the
55
+ :class:`braindecode.datasets.BaseDataset` constructor.
41
56
 
42
- This is a subclass of braindecode's BaseDataset, which can consequently be used in
43
- conjunction with the preprocessing and training pipelines of braindecode.
44
57
  """
45
58
 
46
59
  _AWS_BUCKET = "s3://openneuro.org"
@@ -52,20 +65,6 @@ class EEGDashBaseDataset(BaseDataset):
52
65
  s3_bucket: str | None = None,
53
66
  **kwargs,
54
67
  ):
55
- """Create a new EEGDashBaseDataset instance. Users do not usually need to call this
56
- directly -- instead use the EEGDashDataset class to load a collection of these
57
- recordings from a local BIDS folder or using a database query.
58
-
59
- Parameters
60
- ----------
61
- record : dict
62
- A fully resolved metadata record for the data to load.
63
- cache_dir : str
64
- A local directory where the data will be cached.
65
- kwargs : dict
66
- Additional keyword arguments to pass to the BaseDataset constructor.
67
-
68
- """
69
68
  super().__init__(None, **kwargs)
70
69
  self.record = record
71
70
  self.cache_dir = Path(cache_dir)
@@ -121,14 +120,12 @@ class EEGDashBaseDataset(BaseDataset):
121
120
  self._raw = None
122
121
 
123
122
  def _get_raw_bids_args(self) -> dict[str, Any]:
124
- """Helper to restrict the metadata record to the fields needed to locate a BIDS
125
- recording.
126
- """
123
+ """Extract BIDS-related arguments from the metadata record."""
127
124
  desired_fields = ["subject", "session", "task", "run"]
128
125
  return {k: self.record[k] for k in desired_fields if self.record[k]}
129
126
 
130
127
  def _ensure_raw(self) -> None:
131
- """Download the S3 file and BIDS dependencies if not already cached."""
128
+ """Ensure the raw data file and its dependencies are cached locally."""
132
129
  # TO-DO: remove this once is fixed on the our side
133
130
  # for the competition
134
131
  if not self.s3_open_neuro:
@@ -190,42 +187,53 @@ class EEGDashBaseDataset(BaseDataset):
190
187
  return len(self._raw)
191
188
 
192
189
  @property
193
- def raw(self):
194
- """Return the MNE Raw object for this recording. This will perform the actual
195
- retrieval if not yet done so.
190
+ def raw(self) -> BaseRaw:
191
+ """The MNE Raw object for this recording.
192
+
193
+ Accessing this property triggers the download and caching of the data
194
+ if it has not been accessed before.
195
+
196
+ Returns
197
+ -------
198
+ mne.io.BaseRaw
199
+ The loaded MNE Raw object.
200
+
196
201
  """
197
202
  if self._raw is None:
198
203
  self._ensure_raw()
199
204
  return self._raw
200
205
 
201
206
  @raw.setter
202
- def raw(self, raw):
207
+ def raw(self, raw: BaseRaw):
203
208
  self._raw = raw
204
209
 
205
210
 
206
211
  class EEGDashBaseRaw(BaseRaw):
207
- """Wrapper around the MNE BaseRaw class that automatically fetches the data from S3
208
- (when _read_segment is called) and caches it locally. Currently for internal use.
212
+ """MNE BaseRaw wrapper for automatic S3 data fetching.
213
+
214
+ This class extends :class:`mne.io.BaseRaw` to automatically fetch data
215
+ from an S3 bucket and cache it locally when data is first accessed.
216
+ It is intended for internal use within the EEGDash ecosystem.
209
217
 
210
218
  Parameters
211
219
  ----------
212
- input_fname : path-like
213
- Path to the S3 file
220
+ input_fname : str
221
+ The path to the file on the S3 bucket (relative to the bucket root).
214
222
  metadata : dict
215
- The metadata record for the recording (e.g., from the database).
216
- preload : bool
217
- Whether to pre-loaded the data before the first access.
218
- cache_dir : str
219
- Local path under which the data will be cached.
220
- bids_dependencies : list
221
- List of additional BIDS metadata files that should be downloaded and cached
222
- alongside the main recording file.
223
- verbose : str | int | None
224
- Optionally the verbosity level for MNE logging (see MNE documentation for possible values).
223
+ The metadata record for the recording, containing information like
224
+ sampling frequency, channel names, etc.
225
+ preload : bool, default False
226
+ If True, preload the data into memory.
227
+ cache_dir : str, optional
228
+ Local directory for caching data. If None, a default directory is used.
229
+ bids_dependencies : list of str, default []
230
+ A list of BIDS metadata files to download alongside the main recording.
231
+ verbose : str, int, or None, default None
232
+ The MNE verbosity level.
225
233
 
226
234
  See Also
227
235
  --------
228
- mne.io.Raw : Documentation of attributes and methods.
236
+ mne.io.Raw : The base class for Raw objects in MNE.
229
237
 
230
238
  """
231
239
 
@@ -241,7 +249,6 @@ class EEGDashBaseRaw(BaseRaw):
241
249
  bids_dependencies: list[str] = [],
242
250
  verbose: Any = None,
243
251
  ):
244
- """Get to work with S3 endpoint first, no caching"""
245
252
  # Create a simple RawArray
246
253
  sfreq = metadata["sfreq"] # Sampling frequency
247
254
  n_times = metadata["n_times"]
@@ -277,6 +284,7 @@ class EEGDashBaseRaw(BaseRaw):
277
284
  def _read_segment(
278
285
  self, start=0, stop=None, sel=None, data_buffer=None, *, verbose=None
279
286
  ):
287
+ """Read a segment of data, downloading if necessary."""
280
288
  if not os.path.exists(self.filecache): # not preload
281
289
  if self.bids_dependencies: # this is use only to sidecars for now
282
290
  downloader.download_dependencies(
@@ -297,22 +305,23 @@ class EEGDashBaseRaw(BaseRaw):
297
305
  return super()._read_segment(start, stop, sel, data_buffer, verbose=verbose)
298
306
 
299
307
  def _read_segment_file(self, data, idx, fi, start, stop, cals, mult):
300
- """Read a chunk of data from the file."""
308
+ """Read a chunk of data from a local file."""
301
309
  _read_segments_file(self, data, idx, fi, start, stop, cals, mult, dtype="<f4")
302
310
 
303
311
 
304
312
  class EEGBIDSDataset:
305
- """A one-stop shop interface to a local BIDS dataset containing EEG recordings.
313
+ """An interface to a local BIDS dataset containing EEG recordings.
306
314
 
307
- This is mainly tailored to the needs of EEGDash application and is used to centralize
308
- interactions with the BIDS dataset, such as parsing the metadata.
315
+ This class centralizes interactions with a BIDS dataset on the local
316
+ filesystem, providing methods to parse metadata, find files, and
317
+ retrieve BIDS-related information.
309
318
 
310
319
  Parameters
311
320
  ----------
312
- data_dir : str | Path
321
+ data_dir : str or Path
313
322
  The path to the local BIDS dataset directory.
314
323
  dataset : str
315
- A name for the dataset.
324
+ A name for the dataset (e.g., "ds002718").
316
325
 
317
326
  """
318
327
 
@@ -338,8 +347,11 @@ class EEGBIDSDataset:
338
347
  ):
339
348
  if data_dir is None or not os.path.exists(data_dir):
340
349
  raise ValueError("data_dir must be specified and must exist")
350
+
341
351
  self.bidsdir = Path(data_dir)
342
352
  self.dataset = dataset
353
+ self.data_dir = data_dir
354
+
343
355
  # Accept exact dataset folder or a variant with informative suffixes
344
356
  # (e.g., dsXXXXX-bdf, dsXXXXX-bdf-mini) to avoid collisions.
345
357
  dir_name = self.bidsdir.name
@@ -347,331 +359,376 @@ class EEGBIDSDataset:
347
359
  raise AssertionError(
348
360
  f"BIDS directory '{dir_name}' does not correspond to dataset '{self.dataset}'"
349
361
  )
350
- self.layout = BIDSLayout(data_dir)
362
+
363
+ # Initialize BIDS paths using fast mne_bids approach instead of pybids
364
+ self._init_bids_paths()
351
365
 
352
366
  # get all recording files in the bids directory
353
- self.files = self._get_recordings(self.layout)
354
367
  assert len(self.files) > 0, ValueError(
355
368
  "Unable to construct EEG dataset. No EEG recordings found."
356
369
  )
357
370
  assert self.check_eeg_dataset(), ValueError("Dataset is not an EEG dataset.")
358
371
 
359
372
  def check_eeg_dataset(self) -> bool:
360
- """Check if the dataset is EEG."""
373
+ """Check if the BIDS dataset contains EEG data.
374
+
375
+ Returns
376
+ -------
377
+ bool
378
+ True if the dataset's modality is EEG, False otherwise.
379
+
380
+ """
361
381
  return self.get_bids_file_attribute("modality", self.files[0]).lower() == "eeg"
362
382
 
363
- def _get_recordings(self, layout: BIDSLayout) -> list[str]:
364
- """Get a list of all EEG recording files in the BIDS layout."""
365
- files = []
366
- for ext, exts in self.RAW_EXTENSIONS.items():
367
- files = layout.get(extension=ext, return_type="filename")
368
- if files:
383
+ def _init_bids_paths(self) -> None:
384
+ """Initialize BIDS file paths using mne_bids for fast discovery.
385
+
386
+ Uses mne_bids.find_matching_paths() for efficient pattern-based file
387
+ discovery instead of heavy pybids BIDSLayout indexing.
388
+ """
389
+ # Initialize cache for BIDSPath objects
390
+ self._bids_path_cache = {}
391
+
392
+ # Find all EEG recordings using pattern matching (fast!)
393
+ self.files = []
394
+ for ext in self.RAW_EXTENSIONS.keys():
395
+ # find_matching_paths returns BIDSPath objects
396
+ paths = find_matching_paths(self.bidsdir, datatypes="eeg", extensions=ext)
397
+ if paths:
398
+ # Convert BIDSPath objects to filename strings
399
+ self.files = [str(p.fpath) for p in paths]
369
400
  break
370
- return files
371
401
 
372
- def _get_relative_bidspath(self, filename: str) -> str:
373
- """Make the given file path relative to the BIDS directory."""
374
- bids_parent_dir = self.bidsdir.parent.absolute()
375
- return str(Path(filename).relative_to(bids_parent_dir))
402
+ def _get_bids_path_from_file(self, data_filepath: str):
403
+ """Get a BIDSPath object for a data file with caching.
376
404
 
377
- def _get_property_from_filename(self, property: str, filename: str) -> str:
378
- """Parse a property out of a BIDS-compliant filename. Returns an empty string
379
- if not found.
380
- """
381
- import platform
405
+ Parameters
406
+ ----------
407
+ data_filepath : str
408
+ The path to the data file.
382
409
 
383
- if platform.system() == "Windows":
384
- lookup = re.search(rf"{property}-(.*?)[_\\]", filename)
385
- else:
386
- lookup = re.search(rf"{property}-(.*?)[_\/]", filename)
387
- return lookup.group(1) if lookup else ""
410
+ Returns
411
+ -------
412
+ BIDSPath
413
+ The BIDSPath object for the file.
388
414
 
389
- def _merge_json_inheritance(self, json_files: list[str | Path]) -> dict:
390
- """Internal helper to merge list of json files found by get_bids_file_inheritance,
391
- expecting the order (from left to right) is from lowest
392
- level to highest level, and return a merged dictionary
393
415
  """
394
- json_files.reverse()
395
- json_dict = {}
396
- for f in json_files:
397
- json_dict.update(json.load(open(f))) # FIXME: should close file
398
- return json_dict
416
+ if data_filepath not in self._bids_path_cache:
417
+ # Parse the filename to extract BIDS entities
418
+ filepath = Path(data_filepath)
419
+ filename = filepath.name
420
+
421
+ # Extract entities from filename using BIDS pattern
422
+ # Expected format: sub-<label>[_ses-<label>][_task-<label>][_run-<label>]_eeg.<ext>
423
+ subject = re.search(r"sub-([^_]*)", filename)
424
+ session = re.search(r"ses-([^_]*)", filename)
425
+ task = re.search(r"task-([^_]*)", filename)
426
+ run = re.search(r"run-([^_]*)", filename)
427
+
428
+ bids_path = BIDSPath(
429
+ subject=subject.group(1) if subject else None,
430
+ session=session.group(1) if session else None,
431
+ task=task.group(1) if task else None,
432
+ run=int(run.group(1)) if run else None,
433
+ datatype="eeg",
434
+ extension=filepath.suffix,
435
+ root=self.bidsdir,
436
+ )
437
+ self._bids_path_cache[data_filepath] = bids_path
399
438
 
400
- def _get_bids_file_inheritance(
401
- self, path: str | Path, basename: str, extension: str
402
- ) -> list[Path]:
403
- """Get all file paths that apply to the basename file in the specified directory
404
- and that end with the specified suffix, recursively searching parent directories
405
- (following the BIDS inheritance principle in the order of lowest level first).
439
+ return self._bids_path_cache[data_filepath]
440
+
441
+ def _get_json_with_inheritance(
442
+ self, data_filepath: str, json_filename: str
443
+ ) -> dict:
444
+ """Get JSON metadata with BIDS inheritance handling.
445
+
446
+ Walks up the directory tree to find and merge JSON files following
447
+ BIDS inheritance principles.
406
448
 
407
449
  Parameters
408
450
  ----------
409
- path : str | Path
410
- The directory path to search for files.
411
- basename : str
412
- BIDS file basename without _eeg.set extension for example
413
- extension : str
414
- Only consider files that end with the specified suffix; e.g. channels.tsv
451
+ data_filepath : str
452
+ The path to the data file.
453
+ json_filename : str
454
+ The name of the JSON file to find (e.g., "eeg.json").
415
455
 
416
456
  Returns
417
457
  -------
418
- list[Path]
419
- A list of file paths that match the given basename and extension.
458
+ dict
459
+ The merged JSON metadata.
420
460
 
421
461
  """
462
+ json_dict = {}
463
+ current_dir = Path(data_filepath).parent
464
+ root_dir = self.bidsdir
465
+
466
+ # Walk up from file directory to root, collecting JSON files
467
+ while current_dir >= root_dir:
468
+ json_path = current_dir / json_filename
469
+ if json_path.exists():
470
+ with open(json_path) as f:
471
+ json_dict.update(json.load(f))
472
+
473
+ # Stop at BIDS root (contains dataset_description.json)
474
+ if (current_dir / "dataset_description.json").exists():
475
+ break
476
+
477
+ current_dir = current_dir.parent
478
+
479
+ return json_dict
480
+
481
+ def _merge_json_inheritance(self, json_files: list[str | Path]) -> dict:
482
+ """Merge a list of JSON files according to BIDS inheritance."""
483
+ json_files.reverse()
484
+ json_dict = {}
485
+ for f in json_files:
486
+ with open(f) as fp:
487
+ json_dict.update(json.load(fp))
488
+ return json_dict
489
+
490
+ def _get_bids_file_inheritance(
491
+ self, path: str | Path, basename: str, extension: str
492
+ ) -> list[Path]:
493
+ """Find all applicable metadata files using BIDS inheritance."""
422
494
  top_level_files = ["README", "dataset_description.json", "participants.tsv"]
423
495
  bids_files = []
424
496
 
425
- # check if path is str object
426
497
  if isinstance(path, str):
427
498
  path = Path(path)
428
- if not path.exists:
429
- raise ValueError("path {path} does not exist")
499
+ if not path.exists():
500
+ raise ValueError(f"path {path} does not exist")
430
501
 
431
- # check if file is in current path
432
502
  for file in os.listdir(path):
433
- # target_file = path / f"{cur_file_basename}_{extension}"
434
- if os.path.isfile(path / file):
435
- # check if file has extension extension
436
- # check if file basename has extension
437
- if file.endswith(extension):
438
- filepath = path / file
439
- bids_files.append(filepath)
440
-
441
- # check if file is in top level directory
503
+ if os.path.isfile(path / file) and file.endswith(extension):
504
+ bids_files.append(path / file)
505
+
442
506
  if any(file in os.listdir(path) for file in top_level_files):
443
507
  return bids_files
444
508
  else:
445
- # call get_bids_file_inheritance recursively with parent directory
446
509
  bids_files.extend(
447
510
  self._get_bids_file_inheritance(path.parent, basename, extension)
448
511
  )
449
512
  return bids_files
450
513
 
451
514
  def get_bids_metadata_files(
452
- self, filepath: str | Path, metadata_file_extension: list[str]
515
+ self, filepath: str | Path, metadata_file_extension: str
453
516
  ) -> list[Path]:
454
- """Retrieve all metadata file paths that apply to a given data file path and that
455
- end with a specific suffix (following the BIDS inheritance principle).
517
+ """Retrieve all metadata files that apply to a given data file.
518
+
519
+ Follows the BIDS inheritance principle to find all relevant metadata
520
+ files (e.g., ``channels.tsv``, ``eeg.json``) for a specific recording.
456
521
 
457
522
  Parameters
458
523
  ----------
459
- filepath: str | Path
460
- The filepath to get the associated metadata files for.
524
+ filepath : str or Path
525
+ The path to the data file.
461
526
  metadata_file_extension : str
462
- Consider only metadata files that end with the specified suffix,
463
- e.g., channels.tsv or eeg.json
527
+ The extension of the metadata file to search for (e.g., "channels.tsv").
464
528
 
465
529
  Returns
466
530
  -------
467
- list[Path]:
468
- A list of filepaths for all matching metadata files
531
+ list of Path
532
+ A list of paths to the matching metadata files.
469
533
 
470
534
  """
471
535
  if isinstance(filepath, str):
472
536
  filepath = Path(filepath)
473
- if not filepath.exists:
474
- raise ValueError("filepath {filepath} does not exist")
537
+ if not filepath.exists():
538
+ raise ValueError(f"filepath {filepath} does not exist")
475
539
  path, filename = os.path.split(filepath)
476
540
  basename = filename[: filename.rfind("_")]
477
- # metadata files
478
541
  meta_files = self._get_bids_file_inheritance(
479
542
  path, basename, metadata_file_extension
480
543
  )
481
544
  return meta_files
482
545
 
483
- def _scan_directory(self, directory: str, extension: str) -> list[Path]:
484
- """Return a list of file paths that end with the given extension in the specified
485
- directory. Ignores certain special directories like .git, .datalad, derivatives,
486
- and code.
546
+ def get_files(self) -> list[str]:
547
+ """Get all EEG recording file paths in the BIDS dataset.
548
+
549
+ Returns
550
+ -------
551
+ list of str
552
+ A list of file paths for all valid EEG recordings.
553
+
487
554
  """
488
- result_files = []
489
- directory_to_ignore = [".git", ".datalad", "derivatives", "code"]
490
- with os.scandir(directory) as entries:
491
- for entry in entries:
492
- if entry.is_file() and entry.name.endswith(extension):
493
- result_files.append(entry.path)
494
- elif entry.is_dir():
495
- # check that entry path doesn't contain any name in ignore list
496
- if not any(name in entry.name for name in directory_to_ignore):
497
- result_files.append(entry.path) # Add directory to scan later
498
- return result_files
499
-
500
- def _get_files_with_extension_parallel(
501
- self, directory: str, extension: str = ".set", max_workers: int = -1
502
- ) -> list[Path]:
503
- """Efficiently scan a directory and its subdirectories for files that end with
504
- the given extension.
555
+ return self.files
556
+
557
+ def get_bids_file_attribute(self, attribute: str, data_filepath: str) -> Any:
558
+ """Retrieve a specific attribute from BIDS metadata.
505
559
 
506
560
  Parameters
507
561
  ----------
508
- directory : str
509
- The root directory to scan for files.
510
- extension : str
511
- Only consider files that end with this suffix, e.g. '.set'.
512
- max_workers : int
513
- Optionally specify the maximum number of worker threads to use for parallel scanning.
514
- Defaults to all available CPU cores if set to -1.
562
+ attribute : str
563
+ The name of the attribute to retrieve (e.g., "sfreq", "subject").
564
+ data_filepath : str
565
+ The path to the data file.
515
566
 
516
567
  Returns
517
568
  -------
518
- list[Path]:
519
- A list of filepaths for all matching metadata files
569
+ Any
570
+ The value of the requested attribute, or None if not found.
520
571
 
521
572
  """
522
- result_files = []
523
- dirs_to_scan = [directory]
573
+ bids_path = self._get_bids_path_from_file(data_filepath)
574
+
575
+ # Direct BIDSPath properties for entities
576
+ direct_attrs = {
577
+ "subject": bids_path.subject,
578
+ "session": bids_path.session,
579
+ "task": bids_path.task,
580
+ "run": bids_path.run,
581
+ "modality": bids_path.datatype,
582
+ }
524
583
 
525
- # Use joblib.Parallel and delayed to parallelize directory scanning
526
- while dirs_to_scan:
527
- logger.info(
528
- f"Directories to scan: {len(dirs_to_scan)}, files: {dirs_to_scan}"
529
- )
530
- # Run the scan_directory function in parallel across directories
531
- results = Parallel(n_jobs=max_workers, prefer="threads", verbose=1)(
532
- delayed(self._scan_directory)(d, extension) for d in dirs_to_scan
533
- )
584
+ if attribute in direct_attrs:
585
+ return direct_attrs[attribute]
534
586
 
535
- # Reset the directories to scan and process the results
536
- dirs_to_scan = []
537
- for res in results:
538
- for path in res:
539
- if os.path.isdir(path):
540
- dirs_to_scan.append(path) # Queue up subdirectories to scan
541
- else:
542
- result_files.append(path) # Add files to the final result
543
- logger.info(f"Found {len(result_files)} files.")
544
-
545
- return result_files
546
-
547
- def load_and_preprocess_raw(
548
- self, raw_file: str, preprocess: bool = False
549
- ) -> np.ndarray:
550
- """Utility function to load a raw data file with MNE and apply some simple
551
- (hardcoded) preprocessing and return as a numpy array. Not meant for purposes
552
- other than testing or debugging.
553
- """
554
- logger.info(f"Loading raw data from {raw_file}")
555
- EEG = mne.io.read_raw_eeglab(raw_file, preload=True, verbose="error")
556
-
557
- if preprocess:
558
- # highpass filter
559
- EEG = EEG.filter(l_freq=0.25, h_freq=25, verbose=False)
560
- # remove 60Hz line noise
561
- EEG = EEG.notch_filter(freqs=(60), verbose=False)
562
- # bring to common sampling rate
563
- sfreq = 128
564
- if EEG.info["sfreq"] != sfreq:
565
- EEG = EEG.resample(sfreq)
566
-
567
- mat_data = EEG.get_data()
568
-
569
- if len(mat_data.shape) > 2:
570
- raise ValueError("Expect raw data to be CxT dimension")
571
- return mat_data
572
-
573
- def get_files(self) -> list[Path]:
574
- """Get all EEG recording file paths (with valid extensions) in the BIDS folder."""
575
- return self.files
587
+ # For JSON-based attributes, read and cache eeg.json
588
+ eeg_json = self._get_json_with_inheritance(data_filepath, "eeg.json")
589
+ json_attrs = {
590
+ "sfreq": eeg_json.get("SamplingFrequency"),
591
+ "ntimes": eeg_json.get("RecordingDuration"),
592
+ "nchans": eeg_json.get("EEGChannelCount"),
593
+ }
594
+
595
+ return json_attrs.get(attribute)
576
596
 
577
- def resolve_bids_json(self, json_files: list[str]) -> dict:
578
- """Resolve the BIDS JSON files and return a dictionary of the resolved values.
597
+ def channel_labels(self, data_filepath: str) -> list[str]:
598
+ """Get a list of channel labels from channels.tsv.
579
599
 
580
600
  Parameters
581
601
  ----------
582
- json_files : list
583
- A list of JSON file paths to resolve in order of leaf level first.
602
+ data_filepath : str
603
+ The path to the data file.
584
604
 
585
605
  Returns
586
606
  -------
587
- dict: A dictionary of the resolved values.
607
+ list of str
608
+ A list of channel names.
588
609
 
589
610
  """
590
- if len(json_files) == 0:
591
- raise ValueError("No JSON files provided")
592
- json_files.reverse() # TODO undeterministic
611
+ # Find channels.tsv in the same directory as the data file
612
+ # It can be named either "channels.tsv" or "*_channels.tsv"
613
+ filepath = Path(data_filepath)
614
+ parent_dir = filepath.parent
615
+
616
+ # Try the standard channels.tsv first
617
+ channels_tsv_path = parent_dir / "channels.tsv"
618
+ if not channels_tsv_path.exists():
619
+ # Try to find *_channels.tsv matching the filename prefix
620
+ base_name = filepath.stem # filename without extension
621
+ for tsv_file in parent_dir.glob("*_channels.tsv"):
622
+ # Check if it matches by looking at task/run components
623
+ tsv_name = tsv_file.stem.replace("_channels", "")
624
+ if base_name.startswith(tsv_name):
625
+ channels_tsv_path = tsv_file
626
+ break
627
+
628
+ if not channels_tsv_path.exists():
629
+ raise FileNotFoundError(f"No channels.tsv found for {data_filepath}")
630
+
631
+ channels_tsv = pd.read_csv(channels_tsv_path, sep="\t")
632
+ return channels_tsv["name"].tolist()
593
633
 
594
- json_dict = {}
595
- for json_file in json_files:
596
- with open(json_file) as f:
597
- json_dict.update(json.load(f))
598
- return json_dict
634
+ def channel_types(self, data_filepath: str) -> list[str]:
635
+ """Get a list of channel types from channels.tsv.
599
636
 
600
- def get_bids_file_attribute(self, attribute: str, data_filepath: str) -> Any:
601
- """Retrieve a specific attribute from the BIDS file metadata applicable
602
- to the provided recording file path.
603
- """
604
- entities = self.layout.parse_file_entities(data_filepath)
605
- bidsfile = self.layout.get(**entities)[0]
606
- attributes = bidsfile.get_entities(metadata="all")
607
- attribute_mapping = {
608
- "sfreq": "SamplingFrequency",
609
- "modality": "datatype",
610
- "task": "task",
611
- "session": "session",
612
- "run": "run",
613
- "subject": "subject",
614
- "ntimes": "RecordingDuration",
615
- "nchans": "EEGChannelCount",
616
- }
617
- attribute_value = attributes.get(attribute_mapping.get(attribute), None)
618
- return attribute_value
637
+ Parameters
638
+ ----------
639
+ data_filepath : str
640
+ The path to the data file.
619
641
 
620
- def channel_labels(self, data_filepath: str) -> list[str]:
621
- """Get a list of channel labels for the given data file path."""
622
- channels_tsv = pd.read_csv(
623
- self.get_bids_metadata_files(data_filepath, "channels.tsv")[0], sep="\t"
624
- )
625
- return channels_tsv["name"].tolist()
642
+ Returns
643
+ -------
644
+ list of str
645
+ A list of channel types.
626
646
 
627
- def channel_types(self, data_filepath: str) -> list[str]:
628
- """Get a list of channel types for the given data file path."""
629
- channels_tsv = pd.read_csv(
630
- self.get_bids_metadata_files(data_filepath, "channels.tsv")[0], sep="\t"
631
- )
647
+ """
648
+ # Find channels.tsv in the same directory as the data file
649
+ # It can be named either "channels.tsv" or "*_channels.tsv"
650
+ filepath = Path(data_filepath)
651
+ parent_dir = filepath.parent
652
+
653
+ # Try the standard channels.tsv first
654
+ channels_tsv_path = parent_dir / "channels.tsv"
655
+ if not channels_tsv_path.exists():
656
+ # Try to find *_channels.tsv matching the filename prefix
657
+ base_name = filepath.stem # filename without extension
658
+ for tsv_file in parent_dir.glob("*_channels.tsv"):
659
+ # Check if it matches by looking at task/run components
660
+ tsv_name = tsv_file.stem.replace("_channels", "")
661
+ if base_name.startswith(tsv_name):
662
+ channels_tsv_path = tsv_file
663
+ break
664
+
665
+ if not channels_tsv_path.exists():
666
+ raise FileNotFoundError(f"No channels.tsv found for {data_filepath}")
667
+
668
+ channels_tsv = pd.read_csv(channels_tsv_path, sep="\t")
632
669
  return channels_tsv["type"].tolist()
633
670
 
634
671
  def num_times(self, data_filepath: str) -> int:
635
- """Get the approximate number of time points in the EEG recording based on the BIDS metadata."""
636
- eeg_jsons = self.get_bids_metadata_files(data_filepath, "eeg.json")
637
- eeg_json_dict = self._merge_json_inheritance(eeg_jsons)
672
+ """Get the number of time points in the recording.
673
+
674
+ Calculated from ``SamplingFrequency`` and ``RecordingDuration`` in eeg.json.
675
+
676
+ Parameters
677
+ ----------
678
+ data_filepath : str
679
+ The path to the data file.
680
+
681
+ Returns
682
+ -------
683
+ int
684
+ The approximate number of time points.
685
+
686
+ """
687
+ eeg_json_dict = self._get_json_with_inheritance(data_filepath, "eeg.json")
638
688
  return int(
639
- eeg_json_dict["SamplingFrequency"] * eeg_json_dict["RecordingDuration"]
689
+ eeg_json_dict.get("SamplingFrequency", 0)
690
+ * eeg_json_dict.get("RecordingDuration", 0)
640
691
  )
641
692
 
642
693
  def subject_participant_tsv(self, data_filepath: str) -> dict[str, Any]:
643
- """Get BIDS participants.tsv record for the subject to which the given file
644
- path corresponds, as a dictionary.
694
+ """Get the participants.tsv record for a subject.
695
+
696
+ Parameters
697
+ ----------
698
+ data_filepath : str
699
+ The path to a data file belonging to the subject.
700
+
701
+ Returns
702
+ -------
703
+ dict
704
+ A dictionary of the subject's information from participants.tsv.
705
+
645
706
  """
646
- participants_tsv = pd.read_csv(
647
- self.get_bids_metadata_files(data_filepath, "participants.tsv")[0], sep="\t"
648
- )
649
- # if participants_tsv is not empty
707
+ participants_tsv_path = self.get_bids_metadata_files(
708
+ data_filepath, "participants.tsv"
709
+ )[0]
710
+ participants_tsv = pd.read_csv(participants_tsv_path, sep="\t")
650
711
  if participants_tsv.empty:
651
712
  return {}
652
- # set 'participant_id' as index
653
713
  participants_tsv.set_index("participant_id", inplace=True)
654
714
  subject = f"sub-{self.get_bids_file_attribute('subject', data_filepath)}"
655
715
  return participants_tsv.loc[subject].to_dict()
656
716
 
657
717
  def eeg_json(self, data_filepath: str) -> dict[str, Any]:
658
- """Get BIDS eeg.json metadata for the given data file path."""
659
- eeg_jsons = self.get_bids_metadata_files(data_filepath, "eeg.json")
660
- eeg_json_dict = self._merge_json_inheritance(eeg_jsons)
661
- return eeg_json_dict
662
-
663
- def channel_tsv(self, data_filepath: str) -> dict[str, Any]:
664
- """Get BIDS channels.tsv metadata for the given data file path, as a dictionary
665
- of lists and/or single values.
718
+ """Get the merged eeg.json metadata for a data file.
719
+
720
+ Parameters
721
+ ----------
722
+ data_filepath : str
723
+ The path to the data file.
724
+
725
+ Returns
726
+ -------
727
+ dict
728
+ The merged eeg.json metadata.
729
+
666
730
  """
667
- channels_tsv = pd.read_csv(
668
- self.get_bids_metadata_files(data_filepath, "channels.tsv")[0], sep="\t"
669
- )
670
- channel_tsv = channels_tsv.to_dict()
671
- # 'name' and 'type' now have a dictionary of index-value. Convert them to list
672
- for list_field in ["name", "type", "units"]:
673
- channel_tsv[list_field] = list(channel_tsv[list_field].values())
674
- return channel_tsv
731
+ return self._get_json_with_inheritance(data_filepath, "eeg.json")
675
732
 
676
733
 
677
734
  __all__ = ["EEGDashBaseDataset", "EEGBIDSDataset", "EEGDashBaseRaw"]