eegdash 0.3.9.dev170082126__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

eegdash/data_utils.py CHANGED
@@ -1,10 +1,19 @@
+ # Authors: The EEGDash contributors.
+ # License: GNU General Public License
+ # Copyright the EEGDash contributors.
+
+ """Data utilities and dataset classes for EEG data handling.
+
+ This module provides core dataset classes for working with EEG data in the EEGDash ecosystem,
+ including classes for individual recordings and collections of datasets. It integrates with
+ braindecode for machine learning workflows and handles data loading from both local and remote sources.
+ """
+
  import io
  import json
- import logging
  import os
  import re
  import traceback
- import warnings
  from contextlib import redirect_stderr
  from pathlib import Path
  from typing import Any
@@ -13,26 +22,41 @@ import mne
  import mne_bids
  import numpy as np
  import pandas as pd
- import s3fs
- from bids import BIDSLayout
- from fsspec.callbacks import TqdmCallback
  from joblib import Parallel, delayed
  from mne._fiff.utils import _read_segments_file
  from mne.io import BaseRaw
+ from mne.utils.check import _soft_import
  from mne_bids import BIDSPath

  from braindecode.datasets import BaseDataset

+ from . import downloader
+ from .bids_eeg_metadata import enrich_from_participants
+ from .logging import logger
  from .paths import get_default_cache_dir

- logger = logging.getLogger("eegdash")
-

  class EEGDashBaseDataset(BaseDataset):
-     """A single EEG recording hosted on AWS S3 and cached locally upon first access.
+     """A single EEG recording dataset.
+
+     Represents a single EEG recording, typically hosted on a remote server (like AWS S3)
+     and cached locally upon first access. This class is a subclass of
+     :class:`braindecode.datasets.BaseDataset` and can be used with braindecode's
+     preprocessing and training pipelines.
+
+     Parameters
+     ----------
+     record : dict
+         A fully resolved metadata record for the data to load.
+     cache_dir : str
+         The local directory where the data will be cached.
+     s3_bucket : str, optional
+         The S3 bucket to download data from. If not provided, defaults to the
+         OpenNeuro bucket.
+     **kwargs
+         Additional keyword arguments passed to the
+         :class:`braindecode.datasets.BaseDataset` constructor.

-     This is a subclass of braindecode's BaseDataset, which can consequently be used in
-     conjunction with the preprocessing and training pipelines of braindecode.
      """

      _AWS_BUCKET = "s3://openneuro.org"
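Editor's note: to make the new constructor docstring concrete, here is a minimal, hypothetical construction sketch. The record keys are assumptions inferred from the fields this file reads (`bidspath`, `bidsdependencies`, `dataset`, plus the BIDS entities); real records come from the EEGDash metadata database, and `EEGDashDataset` remains the usual entry point.

```python
from eegdash.data_utils import EEGDashBaseDataset

# Hypothetical record; real ones come from the EEGDash metadata database
# and may carry additional fields required by the constructor.
record = {
    "dataset": "ds002718",
    "bidspath": "ds002718/sub-012/eeg/sub-012_task-FaceRecognition_eeg.set",
    "bidsdependencies": [],
    "subject": "012",
    "session": None,
    "task": "FaceRecognition",
    "run": None,
}
ds = EEGDashBaseDataset(record, cache_dir="~/eegdash_cache")
```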
@@ -44,20 +68,6 @@ class EEGDashBaseDataset(BaseDataset):
          s3_bucket: str | None = None,
          **kwargs,
      ):
-         """Create a new EEGDashBaseDataset instance. Users do not usually need to call this
-         directly -- instead use the EEGDashDataset class to load a collection of these
-         recordings from a local BIDS folder or using a database query.
-
-         Parameters
-         ----------
-         record : dict
-             A fully resolved metadata record for the data to load.
-         cache_dir : str
-             A local directory where the data will be cached.
-         kwargs : dict
-             Additional keyword arguments to pass to the BaseDataset constructor.
-
-         """
          super().__init__(None, **kwargs)
          self.record = record
          self.cache_dir = Path(cache_dir)
@@ -73,6 +83,7 @@ class EEGDashBaseDataset(BaseDataset):
          # Compute a dataset folder name under cache_dir that encodes preprocessing
          # (e.g., bdf, mini) to avoid overlapping with the original dataset cache.
          self.dataset_folder = record.get("dataset", "")
+         # TODO: remove this hack when competition is over
          if s3_bucket:
              suffixes: list[str] = []
              bucket_lower = str(s3_bucket).lower()
@@ -91,6 +102,7 @@ class EEGDashBaseDataset(BaseDataset):
              rel = Path(self.dataset_folder) / rel
          self.filecache = self.cache_dir / rel
          self.bids_root = self.cache_dir / self.dataset_folder
+
          self.bidspath = BIDSPath(
              root=self.bids_root,
              datatype="eeg",
@@ -98,122 +110,25 @@ class EEGDashBaseDataset(BaseDataset):
              **self.bids_kwargs,
          )

-         self.s3file = self._get_s3path(record["bidspath"])
+         self.s3file = downloader.get_s3path(self.s3_bucket, record["bidspath"])
          self.bids_dependencies = record["bidsdependencies"]
-         # Temporary fix for BIDS dependencies path
-         # just to release to the competition
+         self.bids_dependencies_original = record["bidsdependencies"]
+         # TODO: removing temporary fix for BIDS dependencies path
+         # when the competition is over and dataset is digested properly
          if not self.s3_open_neuro:
-             self.bids_dependencies_original = self.bids_dependencies
              self.bids_dependencies = [
                  dep.split("/", 1)[1] for dep in self.bids_dependencies
              ]

          self._raw = None

-     def _get_s3path(self, filepath: str) -> str:
-         """Helper to form an AWS S3 URI for the given relative filepath."""
-         return f"{self.s3_bucket}/{filepath}"
-
-     def _download_s3(self) -> None:
-         """Download function that gets the raw EEG data from S3."""
-         filesystem = s3fs.S3FileSystem(
-             anon=True, client_kwargs={"region_name": "us-east-2"}
-         )
-         if not self.s3_open_neuro:
-             self.s3file = re.sub(r"(^|/)ds\d{6}/", r"\1", self.s3file, count=1)
-             if self.s3file.endswith(".set"):
-                 self.s3file = self.s3file[:-4] + ".bdf"
-                 self.filecache = self.filecache.with_suffix(".bdf")
-
-         self.filecache.parent.mkdir(parents=True, exist_ok=True)
-         info = filesystem.info(self.s3file)
-         size = info.get("size") or info.get("Size")
-
-         callback = TqdmCallback(
-             size=size,
-             tqdm_kwargs=dict(
-                 desc=f"Downloading {Path(self.s3file).name}",
-                 unit="B",
-                 unit_scale=True,
-                 unit_divisor=1024,
-                 dynamic_ncols=True,
-                 leave=True,
-                 mininterval=0.2,
-                 smoothing=0.1,
-                 miniters=1,
-                 bar_format="{desc}: {percentage:3.0f}%|{bar}| {n_fmt}/{total_fmt} "
-                 "[{elapsed}<{remaining}, {rate_fmt}]",
-             ),
-         )
-         filesystem.get(self.s3file, self.filecache, callback=callback)
-
-         self.filenames = [self.filecache]
-
-     def _download_dependencies(self) -> None:
-         """Download all BIDS dependency files (metadata files, recording sidecar files)
-         from S3 and cache them locally.
-         """
-         filesystem = s3fs.S3FileSystem(
-             anon=True, client_kwargs={"region_name": "us-east-2"}
-         )
-         for i, dep in enumerate(self.bids_dependencies):
-             if not self.s3_open_neuro:
-                 # fix this when our bucket is integrated into the
-                 # mongodb
-                 # if the file have ".set" replace to ".bdf"
-                 if dep.endswith(".set"):
-                     dep = dep[:-4] + ".bdf"
-
-             s3path = self._get_s3path(dep)
-             if not self.s3_open_neuro:
-                 dep = self.bids_dependencies_original[i]
-
-             dep_path = Path(dep)
-             if dep_path.parts and dep_path.parts[0] == self.record.get("dataset"):
-                 dep_local = Path(self.dataset_folder, *dep_path.parts[1:])
-             else:
-                 dep_local = Path(self.dataset_folder) / dep_path
-             filepath = self.cache_dir / dep_local
-             if not self.s3_open_neuro:
-                 if filepath.suffix == ".set":
-                     filepath = filepath.with_suffix(".bdf")
-                 if self.filecache.suffix == ".set":
-                     self.filecache = self.filecache.with_suffix(".bdf")
-
-             # here, we download the dependency and it is fine
-             # in the case of the competition.
-             if not filepath.exists():
-                 filepath.parent.mkdir(parents=True, exist_ok=True)
-                 info = filesystem.info(s3path)
-                 size = info.get("size") or info.get("Size")
-
-                 callback = TqdmCallback(
-                     size=size,
-                     tqdm_kwargs=dict(
-                         desc=f"Downloading {Path(s3path).name}",
-                         unit="B",
-                         unit_scale=True,
-                         unit_divisor=1024,
-                         dynamic_ncols=True,
-                         leave=True,
-                         mininterval=0.2,
-                         smoothing=0.1,
-                         miniters=1,
-                         bar_format="{desc}: {percentage:3.0f}%|{bar}| {n_fmt}/{total_fmt} "
-                         "[{elapsed}<{remaining}, {rate_fmt}]",
-                     ),
-                 )
-                 filesystem.get(s3path, filepath, callback=callback)
-
      def _get_raw_bids_args(self) -> dict[str, Any]:
-         """Helper to restrict the metadata record to the fields needed to locate a BIDS
-         recording.
-         """
+         """Extract BIDS-related arguments from the metadata record."""
          desired_fields = ["subject", "session", "task", "run"]
          return {k: self.record[k] for k in desired_fields if self.record[k]}

      def _ensure_raw(self) -> None:
-         """Download the S3 file and BIDS dependencies if not already cached."""
+         """Ensure the raw data file and its dependencies are cached locally."""
          # TO-DO: remove this once is fixed on the our side
          # for the competition
          if not self.s3_open_neuro:
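Editor's note: the download logic removed above now lives in `eegdash.downloader`, which this diff does not show. Below is a minimal sketch of what `downloader.download_s3_file` plausibly does, reconstructed from the removed `_download_s3` under those assumptions (anonymous S3 access, the competition `.set` to `.bdf` rename, cached download); the progress callback is omitted.

```python
from pathlib import Path

import s3fs


def download_s3_file(s3file: str, filecache: Path, s3_open_neuro: bool) -> Path:
    # Anonymous access to the public bucket, as in the removed code.
    fs = s3fs.S3FileSystem(anon=True, client_kwargs={"region_name": "us-east-2"})
    if not s3_open_neuro and s3file.endswith(".set"):
        # The competition bucket stores .bdf in place of .set (see removed code).
        s3file = s3file[:-4] + ".bdf"
        filecache = filecache.with_suffix(".bdf")
    filecache.parent.mkdir(parents=True, exist_ok=True)
    fs.get(s3file, str(filecache))  # TqdmCallback progress reporting omitted
    return filecache
```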
@@ -222,130 +137,43 @@ class EEGDashBaseDataset(BaseDataset):

          if not os.path.exists(self.filecache): # not preload
              if self.bids_dependencies:
-                 self._download_dependencies()
-             self._download_s3()
+                 downloader.download_dependencies(
+                     s3_bucket=self.s3_bucket,
+                     bids_dependencies=self.bids_dependencies,
+                     bids_dependencies_original=self.bids_dependencies_original,
+                     cache_dir=self.cache_dir,
+                     dataset_folder=self.dataset_folder,
+                     record=self.record,
+                     s3_open_neuro=self.s3_open_neuro,
+                 )
+             self.filecache = downloader.download_s3_file(
+                 self.s3file, self.filecache, self.s3_open_neuro
+             )
+             self.filenames = [self.filecache]
          if self._raw is None:
-             # capturing any warnings
-             # to-do: remove this once is fixed on the mne-bids side.
-             with warnings.catch_warnings(record=True) as w:
-                 # Ensure all warnings are captured into 'w' and not shown to users
-                 warnings.simplefilter("always")
-                 try:
-                     # mne-bids emits RuntimeWarnings to stderr; silence stderr during read
-                     _stderr_buffer = io.StringIO()
-                     with redirect_stderr(_stderr_buffer):
-                         self._raw = mne_bids.read_raw_bids(
-                             bids_path=self.bidspath, verbose="ERROR"
-                         )
-                     # Parse unmapped participants.tsv fields reported by mne-bids and
-                     # inject them into Raw.info and the dataset description generically.
-                     extras = self._extract_unmapped_participants_from_warnings(w)
-                     if extras:
-                         # 1) Attach to Raw.info under subject_info.participants_extras
-                         try:
-                             subject_info = self._raw.info.get("subject_info") or {}
-                             if not isinstance(subject_info, dict):
-                                 subject_info = {}
-                             pe = subject_info.get("participants_extras") or {}
-                             if not isinstance(pe, dict):
-                                 pe = {}
-                             # Merge without overwriting
-                             for k, v in extras.items():
-                                 pe.setdefault(k, v)
-                             subject_info["participants_extras"] = pe
-                             self._raw.info["subject_info"] = subject_info
-                         except Exception:
-                             # Non-fatal; continue
-                             pass
-
-                         # 2) Also add to this dataset's description, if possible, so
-                         # targets can be selected later without naming specifics.
-                         try:
-                             if isinstance(self.description, dict):
-                                 for k, v in extras.items():
-                                     self.description.setdefault(k, v)
-                             elif isinstance(self.description, pd.Series):
-                                 for k, v in extras.items():
-                                     if k not in self.description.index:
-                                         self.description.loc[k] = v
-                         except Exception:
-                             pass
-                 except Exception as e:
-                     logger.error(
-                         f"Error while reading BIDS file: {self.bidspath}\n"
-                         "This may be due to a missing or corrupted file.\n"
-                         "Please check the file and try again."
-                     )
-                     logger.error(f"Exception: {e}")
-                     logger.error(traceback.format_exc())
-                     raise e
-                 # Filter noisy mapping notices from mne-bids; surface others
-                 for captured_warning in w:
-                     try:
-                         msg = str(captured_warning.message)
-                     except Exception:
-                         continue
-                     # Suppress verbose participants mapping messages
-                     if "Unable to map the following column" in msg and "MNE" in msg:
-                         logger.debug(
-                             "Suppressed mne-bids mapping warning while reading BIDS file: %s",
-                             msg,
-                         )
-                         continue
-
-     def _extract_unmapped_participants_from_warnings(
-         self, warnings_list: list[Any]
-     ) -> dict[str, Any]:
-         """Scan captured warnings from mne-bids and extract unmapped participants.tsv
-         entries in a generic way.
-
-         Optionally, the column name can carry a note in parentheses that we ignore
-         for key/value extraction. Returns a mapping of column name -> raw value.
-         """
-         extras: dict[str, Any] = {}
-         header = "Unable to map the following column(s) to MNE:"
-         for wr in warnings_list:
-             try:
-                 msg = str(wr.message)
-             except Exception:
-                 continue
-             if header not in msg:
-                 continue
-             lines = msg.splitlines()
-             # Find the header line, then parse subsequent lines as entries
              try:
-                 idx = next(i for i, ln in enumerate(lines) if header in ln)
-             except StopIteration:
-                 idx = -1
-             for line in lines[idx + 1 :]:
-                 line = line.strip()
-                 if not line:
-                     continue
-                 # Pattern: <col>(optional note): <value>
-                 # Examples: "gender: F", "Ethnicity: Indian", "foo (ignored): bar"
-                 m = re.match(r"^([^:]+?)(?:\s*\([^)]*\))?\s*:\s*(.*)$", line)
-                 if not m:
-                     continue
-                 col = m.group(1).strip()
-                 val = m.group(2).strip()
-                 # Keep original column names as provided to stay agnostic
-                 if col and col not in extras:
-                     extras[col] = val
-         return extras
-
-     # === BaseDataset and PyTorch Dataset interface ===
-
-     def __getitem__(self, index):
-         """Main function to access a sample from the dataset."""
-         X = self.raw[:, index][0]
-         y = None
-         if self.target_name is not None:
-             y = self.description[self.target_name]
-             if isinstance(y, pd.Series):
-                 y = y.to_list()
-         if self.transform is not None:
-             X = self.transform(X)
-         return X, y
+                 # mne-bids can emit noisy warnings to stderr; keep user logs clean
+                 _stderr_buffer = io.StringIO()
+                 with redirect_stderr(_stderr_buffer):
+                     self._raw = mne_bids.read_raw_bids(
+                         bids_path=self.bidspath, verbose="ERROR"
+                     )
+                 # Enrich Raw.info and description with participants.tsv extras
+                 enrich_from_participants(
+                     self.bids_root, self.bidspath, self._raw, self.description
+                 )
+
+             except Exception as e:
+                 logger.error(
+                     f"Error while reading BIDS file: {self.bidspath}\n"
+                     "This may be due to a missing or corrupted file.\n"
+                     "Please check the file and try again.\n"
+                     "Usually erasing the local cache and re-downloading helps.\n"
+                     f"`rm {self.bidspath}`"
+                 )
+                 logger.error(f"Exception: {e}")
+                 logger.error(traceback.format_exc())
+                 raise e

      def __len__(self) -> int:
          """Return the number of samples in the dataset."""
@@ -362,42 +190,53 @@ class EEGDashBaseDataset(BaseDataset):
          return len(self._raw)

      @property
-     def raw(self):
-         """Return the MNE Raw object for this recording. This will perform the actual
-         retrieval if not yet done so.
+     def raw(self) -> BaseRaw:
+         """The MNE Raw object for this recording.
+
+         Accessing this property triggers the download and caching of the data
+         if it has not been accessed before.
+
+         Returns
+         -------
+         mne.io.BaseRaw
+             The loaded MNE Raw object.
+
          """
          if self._raw is None:
              self._ensure_raw()
          return self._raw

      @raw.setter
-     def raw(self, raw):
+     def raw(self, raw: BaseRaw):
          self._raw = raw


  class EEGDashBaseRaw(BaseRaw):
-     """Wrapper around the MNE BaseRaw class that automatically fetches the data from S3
-     (when _read_segment is called) and caches it locally. Currently for internal use.
+     """MNE BaseRaw wrapper for automatic S3 data fetching.
+
+     This class extends :class:`mne.io.BaseRaw` to automatically fetch data
+     from an S3 bucket and cache it locally when data is first accessed.
+     It is intended for internal use within the EEGDash ecosystem.

      Parameters
      ----------
-     input_fname : path-like
-         Path to the S3 file
+     input_fname : str
+         The path to the file on the S3 bucket (relative to the bucket root).
      metadata : dict
-         The metadata record for the recording (e.g., from the database).
-     preload : bool
-         Whether to pre-loaded the data before the first access.
-     cache_dir : str
-         Local path under which the data will be cached.
-     bids_dependencies : list
-         List of additional BIDS metadata files that should be downloaded and cached
-         alongside the main recording file.
-     verbose : str | int | None
-         Optionally the verbosity level for MNE logging (see MNE documentation for possible values).
+         The metadata record for the recording, containing information like
+         sampling frequency, channel names, etc.
+     preload : bool, default False
+         If True, preload the data into memory.
+     cache_dir : str, optional
+         Local directory for caching data. If None, a default directory is used.
+     bids_dependencies : list of str, default []
+         A list of BIDS metadata files to download alongside the main recording.
+     verbose : str, int, or None, default None
+         The MNE verbosity level.

      See Also
      --------
-     mne.io.Raw : Documentation of attributes and methods.
+     mne.io.Raw : The base class for Raw objects in MNE.

      """

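Editor's note: a hypothetical usage sketch of the lazy `raw` property documented above. Nothing is fetched at construction time; the first access triggers download and caching (reusing the `record` sketch from earlier).

```python
ds = EEGDashBaseDataset(record, cache_dir="~/eegdash_cache")  # no I/O yet
raw = ds.raw                 # first access: downloads from S3 and caches
data = raw.get_data()        # numpy array of shape (n_channels, n_times)
print(raw.info["sfreq"], data.shape)
```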
@@ -413,7 +252,6 @@ class EEGDashBaseRaw(BaseRaw):
          bids_dependencies: list[str] = [],
          verbose: Any = None,
      ):
-         """Get to work with S3 endpoint first, no caching"""
          # Create a simple RawArray
          sfreq = metadata["sfreq"] # Sampling frequency
          n_times = metadata["n_times"]
@@ -426,13 +264,16 @@ class EEGDashBaseRaw(BaseRaw):
              ch_types.append(chtype)
          info = mne.create_info(ch_names=ch_names, sfreq=sfreq, ch_types=ch_types)

-         self.s3file = self._get_s3path(input_fname)
+         self.s3file = downloader.get_s3path(self._AWS_BUCKET, input_fname)
          self.cache_dir = Path(cache_dir) if cache_dir else get_default_cache_dir()
          self.filecache = self.cache_dir / input_fname
          self.bids_dependencies = bids_dependencies

          if preload and not os.path.exists(self.filecache):
-             self._download_s3()
+             self.filecache = downloader.download_s3_file(
+                 self.s3file, self.filecache, self.s3_open_neuro
+             )
+             self.filenames = [self.filecache]
              preload = self.filecache

          super().__init__(
@@ -443,56 +284,47 @@ class EEGDashBaseRaw(BaseRaw):
          verbose=verbose,
      )

-     def _get_s3path(self, filepath):
-         return f"{self._AWS_BUCKET}/{filepath}"
-
-     def _download_s3(self) -> None:
-         self.filecache.parent.mkdir(parents=True, exist_ok=True)
-         filesystem = s3fs.S3FileSystem(
-             anon=True, client_kwargs={"region_name": "us-east-2"}
-         )
-         filesystem.download(self.s3file, self.filecache)
-         self.filenames = [self.filecache]
-
-     def _download_dependencies(self):
-         filesystem = s3fs.S3FileSystem(
-             anon=True, client_kwargs={"region_name": "us-east-2"}
-         )
-         for dep in self.bids_dependencies:
-             s3path = self._get_s3path(dep)
-             filepath = self.cache_dir / dep
-             if not filepath.exists():
-                 filepath.parent.mkdir(parents=True, exist_ok=True)
-                 filesystem.download(s3path, filepath)
-
      def _read_segment(
          self, start=0, stop=None, sel=None, data_buffer=None, *, verbose=None
      ):
+         """Read a segment of data, downloading if necessary."""
          if not os.path.exists(self.filecache): # not preload
-             if self.bids_dependencies:
-                 self._download_dependencies()
-             self._download_s3()
+             if self.bids_dependencies: # this is use only to sidecars for now
+                 downloader.download_dependencies(
+                     s3_bucket=self._AWS_BUCKET,
+                     bids_dependencies=self.bids_dependencies,
+                     bids_dependencies_original=None,
+                     cache_dir=self.cache_dir,
+                     dataset_folder=self.filecache,
+                     record={},
+                     s3_open_neuro=self.s3_open_neuro,
+                 )
+             self.filecache = downloader.download_s3_file(
+                 self.s3file, self.filecache, self.s3_open_neuro
+             )
+             self.filenames = [self.filecache]
          else: # not preload and file is not cached
              self.filenames = [self.filecache]
          return super()._read_segment(start, stop, sel, data_buffer, verbose=verbose)

      def _read_segment_file(self, data, idx, fi, start, stop, cals, mult):
-         """Read a chunk of data from the file."""
+         """Read a chunk of data from a local file."""
          _read_segments_file(self, data, idx, fi, start, stop, cals, mult, dtype="<f4")


  class EEGBIDSDataset:
-     """A one-stop shop interface to a local BIDS dataset containing EEG recordings.
+     """An interface to a local BIDS dataset containing EEG recordings.

-     This is mainly tailored to the needs of EEGDash application and is used to centralize
-     interactions with the BIDS dataset, such as parsing the metadata.
+     This class centralizes interactions with a BIDS dataset on the local
+     filesystem, providing methods to parse metadata, find files, and
+     retrieve BIDS-related information.

      Parameters
      ----------
-     data_dir : str | Path
+     data_dir : str or Path
          The path to the local BIDS dataset directory.
      dataset : str
-         A name for the dataset.
+         A name for the dataset (e.g., "ds002718").

      """

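Editor's note: a short, hypothetical walkthrough of the `EEGBIDSDataset` interface described above; the paths are placeholders.

```python
bids = EEGBIDSDataset(data_dir="/data/ds002718", dataset="ds002718")
first = bids.get_files()[0]
print(bids.get_bids_file_attribute("sfreq", first))
print(bids.channel_labels(first)[:5])
```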
@@ -516,6 +348,14 @@ class EEGBIDSDataset:
          data_dir=None, # location of bids dataset
          dataset="", # dataset name
      ):
+         bids_lib = _soft_import("bids", purpose="digestion of datasets", strict=False)
+
+         if bids_lib is None:
+             raise ImportError(
+                 "The 'pybids' package is required to use EEGBIDSDataset. "
+                 "Please install it via 'pip install eegdash[digestion]'."
+             )
+
          if data_dir is None or not os.path.exists(data_dir):
              raise ValueError("data_dir must be specified and must exist")
          self.bidsdir = Path(data_dir)
@@ -527,7 +367,7 @@ class EEGBIDSDataset:
              raise AssertionError(
                  f"BIDS directory '{dir_name}' does not correspond to dataset '{self.dataset}'"
              )
-         self.layout = BIDSLayout(data_dir)
+         self.layout = bids_lib.BIDSLayout(data_dir)

          # get all recording files in the bids directory
          self.files = self._get_recordings(self.layout)
@@ -537,10 +377,17 @@ class EEGBIDSDataset:
          assert self.check_eeg_dataset(), ValueError("Dataset is not an EEG dataset.")

      def check_eeg_dataset(self) -> bool:
-         """Check if the dataset is EEG."""
+         """Check if the BIDS dataset contains EEG data.
+
+         Returns
+         -------
+         bool
+             True if the dataset's modality is EEG, False otherwise.
+
+         """
          return self.get_bids_file_attribute("modality", self.files[0]).lower() == "eeg"

-     def _get_recordings(self, layout: BIDSLayout) -> list[str]:
+     def _get_recordings(self, layout) -> list[str]:
          """Get a list of all EEG recording files in the BIDS layout."""
          files = []
          for ext, exts in self.RAW_EXTENSIONS.items():
@@ -550,14 +397,12 @@ class EEGBIDSDataset:
          return files

      def _get_relative_bidspath(self, filename: str) -> str:
-         """Make the given file path relative to the BIDS directory."""
+         """Make a file path relative to the BIDS parent directory."""
          bids_parent_dir = self.bidsdir.parent.absolute()
          return str(Path(filename).relative_to(bids_parent_dir))

      def _get_property_from_filename(self, property: str, filename: str) -> str:
-         """Parse a property out of a BIDS-compliant filename. Returns an empty string
-         if not found.
-         """
+         """Parse a BIDS entity from a filename."""
          import platform

          if platform.system() == "Windows":
@@ -567,159 +412,106 @@ class EEGBIDSDataset:
          return lookup.group(1) if lookup else ""

      def _merge_json_inheritance(self, json_files: list[str | Path]) -> dict:
-         """Internal helper to merge list of json files found by get_bids_file_inheritance,
-         expecting the order (from left to right) is from lowest
-         level to highest level, and return a merged dictionary
-         """
+         """Merge a list of JSON files according to BIDS inheritance."""
          json_files.reverse()
          json_dict = {}
          for f in json_files:
-             json_dict.update(json.load(open(f))) # FIXME: should close file
+             with open(f) as fp:
+                 json_dict.update(json.load(fp))
          return json_dict

      def _get_bids_file_inheritance(
          self, path: str | Path, basename: str, extension: str
      ) -> list[Path]:
-         """Get all file paths that apply to the basename file in the specified directory
-         and that end with the specified suffix, recursively searching parent directories
-         (following the BIDS inheritance principle in the order of lowest level first).
-
-         Parameters
-         ----------
-         path : str | Path
-             The directory path to search for files.
-         basename : str
-             BIDS file basename without _eeg.set extension for example
-         extension : str
-             Only consider files that end with the specified suffix; e.g. channels.tsv
-
-         Returns
-         -------
-         list[Path]
-             A list of file paths that match the given basename and extension.
-
-         """
+         """Find all applicable metadata files using BIDS inheritance."""
          top_level_files = ["README", "dataset_description.json", "participants.tsv"]
          bids_files = []

-         # check if path is str object
          if isinstance(path, str):
              path = Path(path)
-         if not path.exists:
-             raise ValueError("path {path} does not exist")
+         if not path.exists():
+             raise ValueError(f"path {path} does not exist")

-         # check if file is in current path
          for file in os.listdir(path):
-             # target_file = path / f"{cur_file_basename}_{extension}"
-             if os.path.isfile(path / file):
-                 # check if file has extension extension
-                 # check if file basename has extension
-                 if file.endswith(extension):
-                     filepath = path / file
-                     bids_files.append(filepath)
-
-         # check if file is in top level directory
+             if os.path.isfile(path / file) and file.endswith(extension):
+                 bids_files.append(path / file)
+
          if any(file in os.listdir(path) for file in top_level_files):
              return bids_files
          else:
-             # call get_bids_file_inheritance recursively with parent directory
              bids_files.extend(
                  self._get_bids_file_inheritance(path.parent, basename, extension)
              )
              return bids_files

      def get_bids_metadata_files(
-         self, filepath: str | Path, metadata_file_extension: list[str]
+         self, filepath: str | Path, metadata_file_extension: str
      ) -> list[Path]:
-         """Retrieve all metadata file paths that apply to a given data file path and that
-         end with a specific suffix (following the BIDS inheritance principle).
+         """Retrieve all metadata files that apply to a given data file.
+
+         Follows the BIDS inheritance principle to find all relevant metadata
+         files (e.g., ``channels.tsv``, ``eeg.json``) for a specific recording.

          Parameters
          ----------
-         filepath: str | Path
-             The filepath to get the associated metadata files for.
+         filepath : str or Path
+             The path to the data file.
          metadata_file_extension : str
-             Consider only metadata files that end with the specified suffix,
-             e.g., channels.tsv or eeg.json
+             The extension of the metadata file to search for (e.g., "channels.tsv").

          Returns
          -------
-         list[Path]:
-             A list of filepaths for all matching metadata files
+         list of Path
+             A list of paths to the matching metadata files.

          """
          if isinstance(filepath, str):
              filepath = Path(filepath)
-         if not filepath.exists:
-             raise ValueError("filepath {filepath} does not exist")
+         if not filepath.exists():
+             raise ValueError(f"filepath {filepath} does not exist")
          path, filename = os.path.split(filepath)
          basename = filename[: filename.rfind("_")]
-         # metadata files
          meta_files = self._get_bids_file_inheritance(
              path, basename, metadata_file_extension
          )
          return meta_files

      def _scan_directory(self, directory: str, extension: str) -> list[Path]:
-         """Return a list of file paths that end with the given extension in the specified
-         directory. Ignores certain special directories like .git, .datalad, derivatives,
-         and code.
-         """
+         """Scan a directory for files with a given extension."""
          result_files = []
          directory_to_ignore = [".git", ".datalad", "derivatives", "code"]
          with os.scandir(directory) as entries:
              for entry in entries:
                  if entry.is_file() and entry.name.endswith(extension):
-                     result_files.append(entry.path)
-                 elif entry.is_dir():
-                     # check that entry path doesn't contain any name in ignore list
-                     if not any(name in entry.name for name in directory_to_ignore):
-                         result_files.append(entry.path) # Add directory to scan later
+                     result_files.append(Path(entry.path))
+                 elif entry.is_dir() and not any(
+                     name in entry.name for name in directory_to_ignore
+                 ):
+                     result_files.append(Path(entry.path))
          return result_files

      def _get_files_with_extension_parallel(
          self, directory: str, extension: str = ".set", max_workers: int = -1
      ) -> list[Path]:
-         """Efficiently scan a directory and its subdirectories for files that end with
-         the given extension.
-
-         Parameters
-         ----------
-         directory : str
-             The root directory to scan for files.
-         extension : str
-             Only consider files that end with this suffix, e.g. '.set'.
-         max_workers : int
-             Optionally specify the maximum number of worker threads to use for parallel scanning.
-             Defaults to all available CPU cores if set to -1.
-
-         Returns
-         -------
-         list[Path]:
-             A list of filepaths for all matching metadata files
-
-         """
+         """Scan a directory tree in parallel for files with a given extension."""
          result_files = []
          dirs_to_scan = [directory]

-         # Use joblib.Parallel and delayed to parallelize directory scanning
          while dirs_to_scan:
              logger.info(
                  f"Directories to scan: {len(dirs_to_scan)}, files: {dirs_to_scan}"
              )
-             # Run the scan_directory function in parallel across directories
              results = Parallel(n_jobs=max_workers, prefer="threads", verbose=1)(
                  delayed(self._scan_directory)(d, extension) for d in dirs_to_scan
              )

-             # Reset the directories to scan and process the results
              dirs_to_scan = []
              for res in results:
                  for path in res:
                      if os.path.isdir(path):
-                         dirs_to_scan.append(path) # Queue up subdirectories to scan
+                         dirs_to_scan.append(path)
                      else:
-                         result_files.append(path) # Add files to the final result
+                         result_files.append(path)
              logger.info(f"Found {len(result_files)} files.")

          return result_files
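Editor's note: a worked example of the merge order used by `_merge_json_inheritance` and `resolve_bids_json`, with plain dicts standing in for the JSON files: the list arrives most-specific first and is reversed, so specific sidecars override dataset-level values.

```python
json_files = [
    {"SamplingFrequency": 256, "EEGReference": "Cz"},      # subject-level eeg.json
    {"SamplingFrequency": 512, "PowerLineFrequency": 60},  # dataset-level eeg.json
]
json_files.reverse()  # top level first, so later (more specific) updates win
merged = {}
for d in json_files:
    merged.update(d)
assert merged == {
    "SamplingFrequency": 256,  # subject-level value overrides dataset-level
    "PowerLineFrequency": 60,
    "EEGReference": "Cz",
}
```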
@@ -727,19 +519,29 @@ class EEGBIDSDataset:
      def load_and_preprocess_raw(
          self, raw_file: str, preprocess: bool = False
      ) -> np.ndarray:
-         """Utility function to load a raw data file with MNE and apply some simple
-         (hardcoded) preprocessing and return as a numpy array. Not meant for purposes
-         other than testing or debugging.
+         """Load and optionally preprocess a raw data file.
+
+         This is a utility function for testing or debugging, not for general use.
+
+         Parameters
+         ----------
+         raw_file : str
+             Path to the raw EEGLAB file (.set).
+         preprocess : bool, default False
+             If True, apply a high-pass filter, notch filter, and resample the data.
+
+         Returns
+         -------
+         numpy.ndarray
+             The loaded and processed data as a NumPy array.
+
          """
          logger.info(f"Loading raw data from {raw_file}")
          EEG = mne.io.read_raw_eeglab(raw_file, preload=True, verbose="error")

          if preprocess:
-             # highpass filter
              EEG = EEG.filter(l_freq=0.25, h_freq=25, verbose=False)
-             # remove 60Hz line noise
              EEG = EEG.notch_filter(freqs=(60), verbose=False)
-             # bring to common sampling rate
              sfreq = 128
              if EEG.info["sfreq"] != sfreq:
                  EEG = EEG.resample(sfreq)
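Editor's note: a hypothetical call to the debugging helper above, with a placeholder path. Note that despite the new docstring's "high-pass filter", the code applies a 0.25-25 Hz band-pass, a 60 Hz notch, and resampling to 128 Hz.

```python
arr = bids.load_and_preprocess_raw(
    "/data/ds002718/sub-012/eeg/sub-012_task-FaceRecognition_eeg.set",  # placeholder
    preprocess=True,
)
print(arr.shape)  # channels x times
```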
@@ -750,26 +552,35 @@ class EEGBIDSDataset:
              raise ValueError("Expect raw data to be CxT dimension")
          return mat_data

-     def get_files(self) -> list[Path]:
-         """Get all EEG recording file paths (with valid extensions) in the BIDS folder."""
+     def get_files(self) -> list[str]:
+         """Get all EEG recording file paths in the BIDS dataset.
+
+         Returns
+         -------
+         list of str
+             A list of file paths for all valid EEG recordings.
+
+         """
          return self.files

      def resolve_bids_json(self, json_files: list[str]) -> dict:
-         """Resolve the BIDS JSON files and return a dictionary of the resolved values.
+         """Resolve BIDS JSON inheritance and merge files.

          Parameters
          ----------
-         json_files : list
-             A list of JSON file paths to resolve in order of leaf level first.
+         json_files : list of str
+             A list of JSON file paths, ordered from the lowest (most specific)
+             to highest level of the BIDS hierarchy.

          Returns
          -------
-         dict: A dictionary of the resolved values.
+         dict
+             A dictionary containing the merged JSON data.

          """
-         if len(json_files) == 0:
+         if not json_files:
              raise ValueError("No JSON files provided")
-         json_files.reverse() # TODO undeterministic
+         json_files.reverse()

          json_dict = {}
          for json_file in json_files:
@@ -778,8 +589,20 @@ class EEGBIDSDataset:
              json_dict.update(json.load(open(json_file)))
          return json_dict

      def get_bids_file_attribute(self, attribute: str, data_filepath: str) -> Any:
-         """Retrieve a specific attribute from the BIDS file metadata applicable
-         to the provided recording file path.
+         """Retrieve a specific attribute from BIDS metadata.
+
+         Parameters
+         ----------
+         attribute : str
+             The name of the attribute to retrieve (e.g., "sfreq", "subject").
+         data_filepath : str
+             The path to the data file.
+
+         Returns
+         -------
+         Any
+             The value of the requested attribute, or None if not found.
+
          """
          entities = self.layout.parse_file_entities(data_filepath)
          bidsfile = self.layout.get(**entities)[0]
@@ -798,21 +621,59 @@ class EEGBIDSDataset:
          return attribute_value

      def channel_labels(self, data_filepath: str) -> list[str]:
-         """Get a list of channel labels for the given data file path."""
+         """Get a list of channel labels from channels.tsv.
+
+         Parameters
+         ----------
+         data_filepath : str
+             The path to the data file.
+
+         Returns
+         -------
+         list of str
+             A list of channel names.
+
+         """
          channels_tsv = pd.read_csv(
              self.get_bids_metadata_files(data_filepath, "channels.tsv")[0], sep="\t"
          )
          return channels_tsv["name"].tolist()

      def channel_types(self, data_filepath: str) -> list[str]:
-         """Get a list of channel types for the given data file path."""
+         """Get a list of channel types from channels.tsv.
+
+         Parameters
+         ----------
+         data_filepath : str
+             The path to the data file.
+
+         Returns
+         -------
+         list of str
+             A list of channel types.
+
+         """
          channels_tsv = pd.read_csv(
              self.get_bids_metadata_files(data_filepath, "channels.tsv")[0], sep="\t"
          )
          return channels_tsv["type"].tolist()

      def num_times(self, data_filepath: str) -> int:
-         """Get the approximate number of time points in the EEG recording based on the BIDS metadata."""
+         """Get the number of time points in the recording.
+
+         Calculated from ``SamplingFrequency`` and ``RecordingDuration`` in eeg.json.
+
+         Parameters
+         ----------
+         data_filepath : str
+             The path to the data file.
+
+         Returns
+         -------
+         int
+             The approximate number of time points.
+
+         """
          eeg_jsons = self.get_bids_metadata_files(data_filepath, "eeg.json")
          eeg_json_dict = self._merge_json_inheritance(eeg_jsons)
          return int(
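Editor's note: the `num_times` arithmetic spelled out, with hypothetical sidecar values.

```python
sampling_frequency = 256.0  # Hz, eeg.json "SamplingFrequency"
recording_duration = 600.0  # seconds, eeg.json "RecordingDuration"
n_times = int(sampling_frequency * recording_duration)
assert n_times == 153600
```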
@@ -820,35 +681,71 @@ class EEGBIDSDataset:
          )

      def subject_participant_tsv(self, data_filepath: str) -> dict[str, Any]:
-         """Get BIDS participants.tsv record for the subject to which the given file
-         path corresponds, as a dictionary.
+         """Get the participants.tsv record for a subject.
+
+         Parameters
+         ----------
+         data_filepath : str
+             The path to a data file belonging to the subject.
+
+         Returns
+         -------
+         dict
+             A dictionary of the subject's information from participants.tsv.
+
          """
-         participants_tsv = pd.read_csv(
-             self.get_bids_metadata_files(data_filepath, "participants.tsv")[0], sep="\t"
-         )
-         # if participants_tsv is not empty
+         participants_tsv_path = self.get_bids_metadata_files(
+             data_filepath, "participants.tsv"
+         )[0]
+         participants_tsv = pd.read_csv(participants_tsv_path, sep="\t")
          if participants_tsv.empty:
              return {}
-         # set 'participant_id' as index
          participants_tsv.set_index("participant_id", inplace=True)
          subject = f"sub-{self.get_bids_file_attribute('subject', data_filepath)}"
          return participants_tsv.loc[subject].to_dict()

      def eeg_json(self, data_filepath: str) -> dict[str, Any]:
-         """Get BIDS eeg.json metadata for the given data file path."""
+         """Get the merged eeg.json metadata for a data file.
+
+         Parameters
+         ----------
+         data_filepath : str
+             The path to the data file.
+
+         Returns
+         -------
+         dict
+             The merged eeg.json metadata.
+
+         """
          eeg_jsons = self.get_bids_metadata_files(data_filepath, "eeg.json")
-         eeg_json_dict = self._merge_json_inheritance(eeg_jsons)
-         return eeg_json_dict
+         return self._merge_json_inheritance(eeg_jsons)

      def channel_tsv(self, data_filepath: str) -> dict[str, Any]:
-         """Get BIDS channels.tsv metadata for the given data file path, as a dictionary
-         of lists and/or single values.
+         """Get the channels.tsv metadata as a dictionary.
+
+         Parameters
+         ----------
+         data_filepath : str
+             The path to the data file.
+
+         Returns
+         -------
+         dict
+             The channels.tsv data, with columns as keys.
+
          """
-         channels_tsv = pd.read_csv(
-             self.get_bids_metadata_files(data_filepath, "channels.tsv")[0], sep="\t"
-         )
-         channel_tsv = channels_tsv.to_dict()
-         # 'name' and 'type' now have a dictionary of index-value. Convert them to list
+         channels_tsv_path = self.get_bids_metadata_files(data_filepath, "channels.tsv")[
+             0
+         ]
+         channels_tsv = pd.read_csv(channels_tsv_path, sep="\t")
+         channel_tsv_dict = channels_tsv.to_dict()
          for list_field in ["name", "type", "units"]:
-             channel_tsv[list_field] = list(channel_tsv[list_field].values())
-         return channel_tsv
+             if list_field in channel_tsv_dict:
+                 channel_tsv_dict[list_field] = list(
+                     channel_tsv_dict[list_field].values()
+                 )
+         return channel_tsv_dict
+
+
+ __all__ = ["EEGDashBaseDataset", "EEGBIDSDataset", "EEGDashBaseRaw"]