eegdash 0.3.6.dev182011805__py3-none-any.whl → 0.3.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- eegdash/__init__.py +5 -4
- eegdash/api.py +515 -454
- eegdash/bids_eeg_metadata.py +254 -0
- eegdash/{dataset.py → const.py} +46 -93
- eegdash/data_utils.py +180 -45
- eegdash/dataset/__init__.py +4 -0
- eegdash/dataset/dataset.py +161 -0
- eegdash/dataset/dataset_summary.csv +256 -0
- eegdash/{registry.py → dataset/registry.py} +16 -6
- eegdash/paths.py +28 -0
- eegdash/utils.py +1 -1
- {eegdash-0.3.6.dev182011805.dist-info → eegdash-0.3.7.dist-info}/METADATA +13 -5
- {eegdash-0.3.6.dev182011805.dist-info → eegdash-0.3.7.dist-info}/RECORD +16 -14
- eegdash/data_config.py +0 -34
- eegdash/dataset_summary.csv +0 -256
- eegdash/preprocessing.py +0 -63
- {eegdash-0.3.6.dev182011805.dist-info → eegdash-0.3.7.dist-info}/WHEEL +0 -0
- {eegdash-0.3.6.dev182011805.dist-info → eegdash-0.3.7.dist-info}/licenses/LICENSE +0 -0
- {eegdash-0.3.6.dev182011805.dist-info → eegdash-0.3.7.dist-info}/top_level.txt +0 -0
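Alongside these moves, cache-directory resolution is factored into the new eegdash/paths.py module; data_utils.py imports get_default_cache_dir from it in the diff below. A minimal sketch of the resulting fallback pattern (the function name comes from the import in the diff; its Path return type is assumed):

from pathlib import Path

from eegdash.paths import get_default_cache_dir

cache_dir = None  # e.g. the caller did not supply an explicit cache directory
resolved = Path(cache_dir) if cache_dir else get_default_cache_dir()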
eegdash/data_utils.py
CHANGED
@@ -1,9 +1,11 @@
+import io
 import json
 import logging
 import os
 import re
 import traceback
 import warnings
+from contextlib import redirect_stderr
 from pathlib import Path
 from typing import Any
 
@@ -21,6 +23,8 @@ from mne_bids import BIDSPath
 
 from braindecode.datasets import BaseDataset
 
+from .paths import get_default_cache_dir
+
 logger = logging.getLogger("eegdash")
 
 
@@ -57,7 +61,7 @@ class EEGDashBaseDataset(BaseDataset):
         super().__init__(None, **kwargs)
         self.record = record
         self.cache_dir = Path(cache_dir)
-        self.bids_kwargs = self.
+        self.bids_kwargs = self._get_raw_bids_args()
 
         if s3_bucket:
             self.s3_bucket = s3_bucket
@@ -66,8 +70,27 @@ class EEGDashBaseDataset(BaseDataset):
             self.s3_bucket = self._AWS_BUCKET
             self.s3_open_neuro = True
 
-
-
+        # Compute a dataset folder name under cache_dir that encodes preprocessing
+        # (e.g., bdf, mini) to avoid overlapping with the original dataset cache.
+        self.dataset_folder = record.get("dataset", "")
+        if s3_bucket:
+            suffixes: list[str] = []
+            bucket_lower = str(s3_bucket).lower()
+            if "bdf" in bucket_lower:
+                suffixes.append("bdf")
+            if "mini" in bucket_lower:
+                suffixes.append("mini")
+            if suffixes:
+                self.dataset_folder = f"{self.dataset_folder}-{'-'.join(suffixes)}"
+
+        # Place files under the dataset-specific folder (with suffix if any)
+        rel = Path(record["bidspath"])  # usually starts with dataset id
+        if rel.parts and rel.parts[0] == record.get("dataset"):
+            rel = Path(self.dataset_folder, *rel.parts[1:])
+        else:
+            rel = Path(self.dataset_folder) / rel
+        self.filecache = self.cache_dir / rel
+        self.bids_root = self.cache_dir / self.dataset_folder
         self.bidspath = BIDSPath(
             root=self.bids_root,
             datatype="eeg",
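The suffix logic in the hunk above is small enough to exercise standalone. A minimal sketch of the naming rule, with hypothetical dataset and bucket names:

def dataset_cache_folder(dataset: str, s3_bucket: str | None) -> str:
    # Encode preprocessing hints ("bdf", "mini") found in the bucket name,
    # mirroring the cache-folder rule in the diff above.
    folder = dataset
    if s3_bucket:
        bucket_lower = str(s3_bucket).lower()
        suffixes = [s for s in ("bdf", "mini") if s in bucket_lower]
        if suffixes:
            folder = f"{folder}-{'-'.join(suffixes)}"
    return folder

# A mini BDF challenge bucket maps "dsXXXXX" to "dsXXXXX-bdf-mini", keeping it
# apart from the cache of the unprocessed "dsXXXXX" download.
print(dataset_cache_folder("dsXXXXX", "s3://nmdatasets/NeurIPS25/R5_mini_L100_bdf"))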
@@ -75,7 +98,7 @@ class EEGDashBaseDataset(BaseDataset):
             **self.bids_kwargs,
         )
 
-        self.s3file = self.
+        self.s3file = self._get_s3path(record["bidspath"])
         self.bids_dependencies = record["bidsdependencies"]
         # Temporary fix for BIDS dependencies path
         # just to release to the competition
@@ -87,7 +110,7 @@ class EEGDashBaseDataset(BaseDataset):
 
         self._raw = None
 
-    def
+    def _get_s3path(self, filepath: str) -> str:
         """Helper to form an AWS S3 URI for the given relative filepath."""
         return f"{self.s3_bucket}/{filepath}"
 
@@ -141,14 +164,22 @@ class EEGDashBaseDataset(BaseDataset):
                 if dep.endswith(".set"):
                     dep = dep[:-4] + ".bdf"
 
-            s3path = self.
+            s3path = self._get_s3path(dep)
             if not self.s3_open_neuro:
                 dep = self.bids_dependencies_original[i]
 
-
+            dep_path = Path(dep)
+            if dep_path.parts and dep_path.parts[0] == self.record.get("dataset"):
+                dep_local = Path(self.dataset_folder, *dep_path.parts[1:])
+            else:
+                dep_local = Path(self.dataset_folder) / dep_path
+            filepath = self.cache_dir / dep_local
             if not self.s3_open_neuro:
+                if filepath.suffix == ".set":
+                    filepath = filepath.with_suffix(".bdf")
                 if self.filecache.suffix == ".set":
                     self.filecache = self.filecache.with_suffix(".bdf")
+
             # here, we download the dependency and it is fine
             # in the case of the competition.
             if not filepath.exists():
@@ -174,15 +205,21 @@ class EEGDashBaseDataset(BaseDataset):
                 )
                 filesystem.get(s3path, filepath, callback=callback)
 
-    def
+    def _get_raw_bids_args(self) -> dict[str, Any]:
         """Helper to restrict the metadata record to the fields needed to locate a BIDS
         recording.
         """
         desired_fields = ["subject", "session", "task", "run"]
         return {k: self.record[k] for k in desired_fields if self.record[k]}
 
-    def
+    def _ensure_raw(self) -> None:
         """Download the S3 file and BIDS dependencies if not already cached."""
+        # TO-DO: remove this once is fixed on the our side
+        # for the competition
+        if not self.s3_open_neuro:
+            self.bidspath = self.bidspath.update(extension=".bdf")
+            self.filecache = self.filecache.with_suffix(".bdf")
+
         if not os.path.exists(self.filecache):  # not preload
             if self.bids_dependencies:
                 self._download_dependencies()
@@ -191,14 +228,50 @@ class EEGDashBaseDataset(BaseDataset):
         # capturing any warnings
         # to-do: remove this once is fixed on the mne-bids side.
         with warnings.catch_warnings(record=True) as w:
+            # Ensure all warnings are captured into 'w' and not shown to users
+            warnings.simplefilter("always")
             try:
-                #
-
-
-
-
-
-
+                # mne-bids emits RuntimeWarnings to stderr; silence stderr during read
+                _stderr_buffer = io.StringIO()
+                with redirect_stderr(_stderr_buffer):
+                    self._raw = mne_bids.read_raw_bids(
+                        bids_path=self.bidspath, verbose="ERROR"
+                    )
+                # Parse unmapped participants.tsv fields reported by mne-bids and
+                # inject them into Raw.info and the dataset description generically.
+                extras = self._extract_unmapped_participants_from_warnings(w)
+                if extras:
+                    # 1) Attach to Raw.info under subject_info.participants_extras
+                    try:
+                        subject_info = self._raw.info.get("subject_info") or {}
+                        if not isinstance(subject_info, dict):
+                            subject_info = {}
+                        pe = subject_info.get("participants_extras") or {}
+                        if not isinstance(pe, dict):
+                            pe = {}
+                        # Merge without overwriting
+                        for k, v in extras.items():
+                            pe.setdefault(k, v)
+                        subject_info["participants_extras"] = pe
+                        self._raw.info["subject_info"] = subject_info
+                    except Exception:
+                        # Non-fatal; continue
+                        pass
+
+                    # 2) Also add to this dataset's description, if possible, so
+                    #    targets can be selected later without naming specifics.
+                    try:
+                        import pandas as _pd  # local import to avoid top-level cost
+
+                        if isinstance(self.description, dict):
+                            for k, v in extras.items():
+                                self.description.setdefault(k, v)
+                        elif isinstance(self.description, _pd.Series):
+                            for k, v in extras.items():
+                                if k not in self.description.index:
+                                    self.description.loc[k] = v
+                    except Exception:
+                        pass
             except Exception as e:
                 logger.error(
                     f"Error while reading BIDS file: {self.bidspath}\n"
@@ -208,10 +281,60 @@ class EEGDashBaseDataset(BaseDataset):
                 logger.error(f"Exception: {e}")
                 logger.error(traceback.format_exc())
                 raise e
-
-
-
-
+        # Filter noisy mapping notices from mne-bids; surface others
+        for captured_warning in w:
+            try:
+                msg = str(captured_warning.message)
+            except Exception:
+                continue
+            # Suppress verbose participants mapping messages
+            if "Unable to map the following column" in msg and "MNE" in msg:
+                logger.debug(
+                    "Suppressed mne-bids mapping warning while reading BIDS file: %s",
+                    msg,
+                )
+                continue
+            logger.warning("Warning while reading BIDS file: %s", msg)
+
+    def _extract_unmapped_participants_from_warnings(
+        self, warnings_list: list[Any]
+    ) -> dict[str, Any]:
+        """Scan captured warnings from mne-bids and extract unmapped participants.tsv
+        entries in a generic way.
+
+        Optionally, the column name can carry a note in parentheses that we ignore
+        for key/value extraction. Returns a mapping of column name -> raw value.
+        """
+        extras: dict[str, Any] = {}
+        header = "Unable to map the following column(s) to MNE:"
+        for wr in warnings_list:
+            try:
+                msg = str(wr.message)
+            except Exception:
+                continue
+            if header not in msg:
+                continue
+            lines = msg.splitlines()
+            # Find the header line, then parse subsequent lines as entries
+            try:
+                idx = next(i for i, ln in enumerate(lines) if header in ln)
+            except StopIteration:
+                idx = -1
+            for line in lines[idx + 1 :]:
+                line = line.strip()
+                if not line:
+                    continue
+                # Pattern: <col>(optional note): <value>
+                # Examples: "gender: F", "Ethnicity: Indian", "foo (ignored): bar"
+                m = re.match(r"^([^:]+?)(?:\s*\([^)]*\))?\s*:\s*(.*)$", line)
+                if not m:
+                    continue
+                col = m.group(1).strip()
+                val = m.group(2).strip()
+                # Keep original column names as provided to stay agnostic
+                if col and col not in extras:
+                    extras[col] = val
+        return extras
 
     # === BaseDataset and PyTorch Dataset interface ===
 
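The capture-and-parse flow above can be exercised without mne-bids. A minimal sketch that feeds a fabricated warning (the message shape is assumed from the header string; the regex is the one in the diff) through the same parsing:

import re
import warnings

fake_msg = (
    "Unable to map the following column(s) to MNE:\n"
    "gender: F\n"
    "Ethnicity (self-reported): Indian"
)

with warnings.catch_warnings(record=True) as w:
    warnings.simplefilter("always")  # record every warning instead of printing
    warnings.warn(fake_msg, RuntimeWarning)

extras: dict[str, str] = {}
for wr in w:
    for line in str(wr.message).splitlines()[1:]:
        m = re.match(r"^([^:]+?)(?:\s*\([^)]*\))?\s*:\s*(.*)$", line.strip())
        if m:
            extras.setdefault(m.group(1).strip(), m.group(2).strip())

print(extras)  # {'gender': 'F', 'Ethnicity': 'Indian'}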
@@ -230,11 +353,16 @@ class EEGDashBaseDataset(BaseDataset):
     def __len__(self) -> int:
         """Return the number of samples in the dataset."""
         if self._raw is None:
-
-
-
-
-
+            if (
+                self.record["ntimes"] is None
+                or self.record["sampling_frequency"] is None
+            ):
+                self._ensure_raw()
+            else:
+                # FIXME: this is a bit strange and should definitely not change as a side effect
+                # of accessing the data (which it will, since ntimes is the actual length but rounded down)
+                return int(self.record["ntimes"] * self.record["sampling_frequency"])
+        return len(self._raw)
 
     @property
     def raw(self):
@@ -242,7 +370,7 @@ class EEGDashBaseDataset(BaseDataset):
         retrieval if not yet done so.
         """
         if self._raw is None:
-            self.
+            self._ensure_raw()
         return self._raw
 
     @raw.setter
@@ -284,7 +412,7 @@ class EEGDashBaseRaw(BaseRaw):
         metadata: dict[str, Any],
         preload: bool = False,
         *,
-        cache_dir: str =
+        cache_dir: str | None = None,
         bids_dependencies: list[str] = [],
         verbose: Any = None,
     ):
@@ -300,8 +428,9 @@ class EEGDashBaseRaw(BaseRaw):
                 chtype = "eog"
             ch_types.append(chtype)
         info = mne.create_info(ch_names=ch_names, sfreq=sfreq, ch_types=ch_types)
-
-        self.
+
+        self.s3file = self._get_s3path(input_fname)
+        self.cache_dir = Path(cache_dir) if cache_dir else get_default_cache_dir()
         self.filecache = self.cache_dir / input_fname
         self.bids_dependencies = bids_dependencies
 
@@ -317,7 +446,7 @@ class EEGDashBaseRaw(BaseRaw):
             verbose=verbose,
         )
 
-    def
+    def _get_s3path(self, filepath):
         return f"{self._AWS_BUCKET}/{filepath}"
 
     def _download_s3(self) -> None:
@@ -333,7 +462,7 @@ class EEGDashBaseRaw(BaseRaw):
             anon=True, client_kwargs={"region_name": "us-east-2"}
         )
         for dep in self.bids_dependencies:
-            s3path = self.
+            s3path = self._get_s3path(dep)
             filepath = self.cache_dir / dep
             if not filepath.exists():
                 filepath.parent.mkdir(parents=True, exist_ok=True)
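The context lines above show the anonymous s3fs access used for the OpenNeuro bucket. A minimal standalone sketch of that download pattern (bucket key and local paths are hypothetical):

from pathlib import Path

import s3fs

filesystem = s3fs.S3FileSystem(anon=True, client_kwargs={"region_name": "us-east-2"})
s3path = "s3://openneuro.org/dsXXXXX/sub-01/eeg/sub-01_task-rest_eeg.set"
filepath = Path("cache/dsXXXXX/sub-01/eeg/sub-01_task-rest_eeg.set")
if not filepath.exists():
    filepath.parent.mkdir(parents=True, exist_ok=True)
    filesystem.get(s3path, str(filepath))  # fetch once; later reads use the local copy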
@@ -394,11 +523,17 @@ class EEGBIDSDataset:
             raise ValueError("data_dir must be specified and must exist")
         self.bidsdir = Path(data_dir)
         self.dataset = dataset
-
+        # Accept exact dataset folder or a variant with informative suffixes
+        # (e.g., dsXXXXX-bdf, dsXXXXX-bdf-mini) to avoid collisions.
+        dir_name = self.bidsdir.name
+        if not (dir_name == self.dataset or dir_name.startswith(self.dataset + "-")):
+            raise AssertionError(
+                f"BIDS directory '{dir_name}' does not correspond to dataset '{self.dataset}'"
+            )
         self.layout = BIDSLayout(data_dir)
 
         # get all recording files in the bids directory
-        self.files = self.
+        self.files = self._get_recordings(self.layout)
         assert len(self.files) > 0, ValueError(
             "Unable to construct EEG dataset. No EEG recordings found."
         )
@@ -408,7 +543,7 @@ class EEGBIDSDataset:
         """Check if the dataset is EEG."""
         return self.get_bids_file_attribute("modality", self.files[0]).lower() == "eeg"
 
-    def
+    def _get_recordings(self, layout: BIDSLayout) -> list[str]:
         """Get a list of all EEG recording files in the BIDS layout."""
         files = []
         for ext, exts in self.RAW_EXTENSIONS.items():
@@ -417,12 +552,12 @@ class EEGBIDSDataset:
                 break
         return files
 
-    def
+    def _get_relative_bidspath(self, filename: str) -> str:
         """Make the given file path relative to the BIDS directory."""
         bids_parent_dir = self.bidsdir.parent.absolute()
         return str(Path(filename).relative_to(bids_parent_dir))
 
-    def
+    def _get_property_from_filename(self, property: str, filename: str) -> str:
         """Parse a property out of a BIDS-compliant filename. Returns an empty string
         if not found.
         """
@@ -434,7 +569,7 @@ class EEGBIDSDataset:
         lookup = re.search(rf"{property}-(.*?)[_\/]", filename)
         return lookup.group(1) if lookup else ""
 
-    def
+    def _merge_json_inheritance(self, json_files: list[str | Path]) -> dict:
         """Internal helper to merge list of json files found by get_bids_file_inheritance,
         expecting the order (from left to right) is from lowest
         level to highest level, and return a merged dictionary
@@ -445,7 +580,7 @@ class EEGBIDSDataset:
             json_dict.update(json.load(open(f)))  # FIXME: should close file
         return json_dict
 
-    def
+    def _get_bids_file_inheritance(
         self, path: str | Path, basename: str, extension: str
     ) -> list[Path]:
         """Get all file paths that apply to the basename file in the specified directory
@@ -492,7 +627,7 @@ class EEGBIDSDataset:
         else:
             # call get_bids_file_inheritance recursively with parent directory
             bids_files.extend(
-                self.
+                self._get_bids_file_inheritance(path.parent, basename, extension)
             )
         return bids_files
 
@@ -523,12 +658,12 @@ class EEGBIDSDataset:
         path, filename = os.path.split(filepath)
         basename = filename[: filename.rfind("_")]
         # metadata files
-        meta_files = self.
+        meta_files = self._get_bids_file_inheritance(
             path, basename, metadata_file_extension
         )
         return meta_files
 
-    def
+    def _scan_directory(self, directory: str, extension: str) -> list[Path]:
         """Return a list of file paths that end with the given extension in the specified
         directory. Ignores certain special directories like .git, .datalad, derivatives,
         and code.
@@ -545,7 +680,7 @@ class EEGBIDSDataset:
                 result_files.append(entry.path)  # Add directory to scan later
         return result_files
 
-    def
+    def _get_files_with_extension_parallel(
         self, directory: str, extension: str = ".set", max_workers: int = -1
     ) -> list[Path]:
         """Efficiently scan a directory and its subdirectories for files that end with
@@ -577,7 +712,7 @@ class EEGBIDSDataset:
         )
         # Run the scan_directory function in parallel across directories
         results = Parallel(n_jobs=max_workers, prefer="threads", verbose=1)(
-            delayed(self.
+            delayed(self._scan_directory)(d, extension) for d in dirs_to_scan
        )
 
        # Reset the directories to scan and process the results
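The Parallel call above is joblib's thread-based map: each round scans one batch of directories, and the surrounding method re-queues subdirectories for the next round. A minimal sketch of a single round (directory names hypothetical):

import os

from joblib import Parallel, delayed

def scan_directory(directory: str, extension: str) -> list[str]:
    # One non-recursive pass; the method in the diff also collects
    # subdirectories so they can be scanned in a later round.
    return [
        entry.path
        for entry in os.scandir(directory)
        if entry.is_file() and entry.name.endswith(extension)
    ]

dirs_to_scan = ["bids_root/sub-01/eeg", "bids_root/sub-02/eeg"]
results = Parallel(n_jobs=-1, prefer="threads", verbose=1)(
    delayed(scan_directory)(d, ".set") for d in dirs_to_scan
)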
@@ -682,7 +817,7 @@ class EEGBIDSDataset:
     def num_times(self, data_filepath: str) -> int:
         """Get the approximate number of time points in the EEG recording based on the BIDS metadata."""
         eeg_jsons = self.get_bids_metadata_files(data_filepath, "eeg.json")
-        eeg_json_dict = self.
+        eeg_json_dict = self._merge_json_inheritance(eeg_jsons)
         return int(
             eeg_json_dict["SamplingFrequency"] * eeg_json_dict["RecordingDuration"]
         )
@@ -705,7 +840,7 @@ class EEGBIDSDataset:
     def eeg_json(self, data_filepath: str) -> dict[str, Any]:
         """Get BIDS eeg.json metadata for the given data file path."""
         eeg_jsons = self.get_bids_metadata_files(data_filepath, "eeg.json")
-        eeg_json_dict = self.
+        eeg_json_dict = self._merge_json_inheritance(eeg_jsons)
         return eeg_json_dict
 
     def channel_tsv(self, data_filepath: str) -> dict[str, Any]:
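num_times above depends on _merge_json_inheritance applying eeg.json files from the dataset level down to the recording level, with later (more specific) files overriding earlier ones. A minimal sketch with made-up values:

dataset_level = {"SamplingFrequency": 500, "RecordingDuration": 300.0}
recording_level = {"RecordingDuration": 312.5}  # more specific file wins

eeg_json_dict: dict = {}
for level in (dataset_level, recording_level):  # lowest -> highest level
    eeg_json_dict.update(level)

num_times = int(eeg_json_dict["SamplingFrequency"] * eeg_json_dict["RecordingDuration"])
print(num_times)  # 156250 samples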
eegdash/dataset/dataset.py
ADDED
@@ -0,0 +1,161 @@
+import logging
+from pathlib import Path
+
+from mne.utils import warn
+
+from ..api import EEGDashDataset
+from ..bids_eeg_metadata import build_query_from_kwargs
+from ..const import RELEASE_TO_OPENNEURO_DATASET_MAP, SUBJECT_MINI_RELEASE_MAP
+from .registry import register_openneuro_datasets
+
+logger = logging.getLogger("eegdash")
+
+
+class EEGChallengeDataset(EEGDashDataset):
+    def __init__(
+        self,
+        release: str,
+        cache_dir: str,
+        mini: bool = True,
+        query: dict | None = None,
+        s3_bucket: str | None = "s3://nmdatasets/NeurIPS25",
+        **kwargs,
+    ):
+        """Create a new EEGDashDataset from a given query or local BIDS dataset directory
+        and dataset name. An EEGDashDataset is pooled collection of EEGDashBaseDataset
+        instances (individual recordings) and is a subclass of braindecode's BaseConcatDataset.
+
+        Parameters
+        ----------
+        release: str
+            Release name. Can be one of ["R1", ..., "R11"]
+        mini: bool, default True
+            Whether to use the mini-release version of the dataset. It is recommended
+            to use the mini version for faster training and evaluation.
+        query : dict | None
+            Optionally a dictionary that specifies a query to be executed,
+            in addition to the dataset (automatically inferred from the release argument).
+            See EEGDash.find() for details on the query format.
+        cache_dir : str
+            A directory where the dataset will be cached locally.
+        s3_bucket : str | None
+            An optional S3 bucket URI to use instead of the
+            default OpenNeuro bucket for loading data files.
+        kwargs : dict
+            Additional keyword arguments to be passed to the EEGDashDataset
+            constructor.
+
+        """
+        self.release = release
+        self.mini = mini
+
+        if release not in RELEASE_TO_OPENNEURO_DATASET_MAP:
+            raise ValueError(
+                f"Unknown release: {release}, expected one of {list(RELEASE_TO_OPENNEURO_DATASET_MAP.keys())}"
+            )
+
+        dataset_parameters = []
+        if isinstance(release, str):
+            dataset_parameters.append(RELEASE_TO_OPENNEURO_DATASET_MAP[release])
+        else:
+            raise ValueError(
+                f"Unknown release type: {type(release)}, the expected type is str."
+            )
+
+        if query and "dataset" in query:
+            raise ValueError(
+                "Query using the parameters `dataset` with the class EEGChallengeDataset is not possible."
+                "Please use the release argument instead, or the object EEGDashDataset instead."
+            )
+
+        if self.mini:
+            # When using the mini release, restrict subjects to the predefined subset.
+            # If the user specifies subject(s), ensure they all belong to the mini subset;
+            # otherwise, default to the full mini subject list for this release.
+
+            allowed_subjects = set(SUBJECT_MINI_RELEASE_MAP[release])
+
+            # Normalize potential 'subjects' -> 'subject' for convenience
+            if "subjects" in kwargs and "subject" not in kwargs:
+                kwargs["subject"] = kwargs.pop("subjects")
+
+            # Collect user-requested subjects from kwargs/query. We canonicalize
+            # kwargs via build_query_from_kwargs to leverage existing validation,
+            # and support Mongo-style {"$in": [...]} shapes from a raw query.
+            requested_subjects: list[str] = []
+
+            # From kwargs
+            if "subject" in kwargs and kwargs["subject"] is not None:
+                # Use the shared query builder to normalize scalars/lists
+                built = build_query_from_kwargs(subject=kwargs["subject"])
+                s_val = built.get("subject")
+                if isinstance(s_val, dict) and "$in" in s_val:
+                    requested_subjects.extend(list(s_val["$in"]))
+                elif s_val is not None:
+                    requested_subjects.append(s_val)  # type: ignore[arg-type]
+
+            # From query (top-level only)
+            if query and isinstance(query, dict) and "subject" in query:
+                qval = query["subject"]
+                if isinstance(qval, dict) and "$in" in qval:
+                    requested_subjects.extend(list(qval["$in"]))
+                elif isinstance(qval, (list, tuple, set)):
+                    requested_subjects.extend(list(qval))
+                elif qval is not None:
+                    requested_subjects.append(qval)
+
+            # Validate if any subjects were explicitly requested
+            if requested_subjects:
+                invalid = sorted(
+                    {s for s in requested_subjects if s not in allowed_subjects}
+                )
+                if invalid:
+                    raise ValueError(
+                        "Some requested subject(s) are not part of the mini release for "
+                        f"{release}: {invalid}. Allowed subjects: {sorted(allowed_subjects)}"
+                    )
+                # Do not override user selection; keep their (validated) subjects as-is.
+            else:
+                # No subject specified by the user: default to the full mini subset
+                kwargs["subject"] = sorted(allowed_subjects)
+
+            s3_bucket = f"{s3_bucket}/{release}_mini_L100_bdf"
+        else:
+            s3_bucket = f"{s3_bucket}/{release}_L100_bdf"
+
+        warn(
+            "\n\n"
+            "[EEGChallengeDataset] EEG 2025 Competition Data Notice:\n"
+            "-------------------------------------------------------\n"
+            "This object loads the HBN dataset that has been preprocessed for the EEG Challenge:\n"
+            "  - Downsampled from 500Hz to 100Hz\n"
+            "  - Bandpass filtered (0.5–50 Hz)\n"
+            "\n"
+            "For full preprocessing details, see:\n"
+            "    https://github.com/eeg2025/downsample-datasets\n"
+            "\n"
+            "IMPORTANT: The data accessed via `EEGChallengeDataset` is NOT identical to what you get from `EEGDashDataset` directly.\n"
+            "If you are participating in the competition, always use `EEGChallengeDataset` to ensure consistency with the challenge data.\n"
+            "\n",
+            UserWarning,
+            module="eegdash",
+        )
+
+        super().__init__(
+            dataset=RELEASE_TO_OPENNEURO_DATASET_MAP[release],
+            query=query,
+            cache_dir=cache_dir,
+            s3_bucket=s3_bucket,
+            _suppress_comp_warning=True,
+            **kwargs,
+        )
+
+
+registered_classes = register_openneuro_datasets(
+    summary_file=Path(__file__).with_name("dataset_summary.csv"),
+    base_class=EEGDashDataset,
+    namespace=globals(),
+)
+
+
+__all__ = ["EEGChallengeDataset"] + list(registered_classes.keys())
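A hedged usage sketch for the class added above. The release id and cache path are examples; the .datasets attribute is assumed from braindecode's BaseConcatDataset, which the docstring names as the base class:

from eegdash.dataset.dataset import EEGChallengeDataset

ds = EEGChallengeDataset(release="R5", cache_dir="./eegdash_cache", mini=True)
print(len(ds.datasets))  # one EEGDashBaseDataset per recording in the mini subset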