eegdash 0.3.7.dev177024734__py3-none-any.whl → 0.3.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- eegdash/__init__.py +5 -5
- eegdash/api.py +528 -460
- eegdash/bids_eeg_metadata.py +254 -0
- eegdash/const.py +48 -0
- eegdash/data_utils.py +177 -45
- eegdash/dataset/__init__.py +4 -0
- eegdash/{dataset.py → dataset/dataset.py} +53 -10
- eegdash/dataset/dataset_summary.csv +256 -0
- eegdash/{registry.py → dataset/registry.py} +3 -3
- eegdash/hbn/__init__.py +17 -0
- eegdash/hbn/windows.py +305 -0
- eegdash/paths.py +28 -0
- eegdash/utils.py +1 -1
- {eegdash-0.3.7.dev177024734.dist-info → eegdash-0.3.8.dist-info}/METADATA +11 -5
- eegdash-0.3.8.dist-info/RECORD +35 -0
- eegdash/data_config.py +0 -34
- eegdash/dataset_summary.csv +0 -256
- eegdash-0.3.7.dev177024734.dist-info/RECORD +0 -31
- /eegdash/{preprocessing.py → hbn/preprocessing.py} +0 -0
- {eegdash-0.3.7.dev177024734.dist-info → eegdash-0.3.8.dist-info}/WHEEL +0 -0
- {eegdash-0.3.7.dev177024734.dist-info → eegdash-0.3.8.dist-info}/licenses/LICENSE +0 -0
- {eegdash-0.3.7.dev177024734.dist-info → eegdash-0.3.8.dist-info}/top_level.txt +0 -0
eegdash/data_utils.py
CHANGED
```diff
@@ -1,9 +1,11 @@
+import io
 import json
 import logging
 import os
 import re
 import traceback
 import warnings
+from contextlib import redirect_stderr
 from pathlib import Path
 from typing import Any
 
@@ -21,6 +23,8 @@ from mne_bids import BIDSPath
 
 from braindecode.datasets import BaseDataset
 
+from .paths import get_default_cache_dir
+
 logger = logging.getLogger("eegdash")
 
 
@@ -57,7 +61,7 @@ class EEGDashBaseDataset(BaseDataset):
         super().__init__(None, **kwargs)
         self.record = record
         self.cache_dir = Path(cache_dir)
-        self.bids_kwargs = self.
+        self.bids_kwargs = self._get_raw_bids_args()
 
         if s3_bucket:
             self.s3_bucket = s3_bucket
@@ -66,8 +70,27 @@ class EEGDashBaseDataset(BaseDataset):
             self.s3_bucket = self._AWS_BUCKET
             self.s3_open_neuro = True
 
-
-
+        # Compute a dataset folder name under cache_dir that encodes preprocessing
+        # (e.g., bdf, mini) to avoid overlapping with the original dataset cache.
+        self.dataset_folder = record.get("dataset", "")
+        if s3_bucket:
+            suffixes: list[str] = []
+            bucket_lower = str(s3_bucket).lower()
+            if "bdf" in bucket_lower:
+                suffixes.append("bdf")
+            if "mini" in bucket_lower:
+                suffixes.append("mini")
+            if suffixes:
+                self.dataset_folder = f"{self.dataset_folder}-{'-'.join(suffixes)}"
+
+        # Place files under the dataset-specific folder (with suffix if any)
+        rel = Path(record["bidspath"])  # usually starts with dataset id
+        if rel.parts and rel.parts[0] == record.get("dataset"):
+            rel = Path(self.dataset_folder, *rel.parts[1:])
+        else:
+            rel = Path(self.dataset_folder) / rel
+        self.filecache = self.cache_dir / rel
+        self.bids_root = self.cache_dir / self.dataset_folder
         self.bidspath = BIDSPath(
             root=self.bids_root,
             datatype="eeg",
```
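The net effect of this block: caches for preprocessed variants of a dataset get their own folder, derived from the bucket name. A minimal standalone sketch of the suffix logic (the function name, dataset ID, and bucket URL below are made up for illustration):

```python
def cache_folder_for(dataset: str, s3_bucket: str | None) -> str:
    """Sketch of the suffix derivation above (not the package's API)."""
    suffixes: list[str] = []
    bucket_lower = str(s3_bucket or "").lower()
    if "bdf" in bucket_lower:
        suffixes.append("bdf")
    if "mini" in bucket_lower:
        suffixes.append("mini")
    return f"{dataset}-{'-'.join(suffixes)}" if suffixes else dataset


# A hypothetical mini/BDF bucket maps to a distinct cache folder:
print(cache_folder_for("ds002718", "s3://example-bucket/R1_mini_L100_bdf"))
# -> ds002718-bdf-mini
print(cache_folder_for("ds002718", None))
# -> ds002718
```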
```diff
@@ -75,7 +98,7 @@ class EEGDashBaseDataset(BaseDataset):
             **self.bids_kwargs,
         )
 
-        self.s3file = self.
+        self.s3file = self._get_s3path(record["bidspath"])
         self.bids_dependencies = record["bidsdependencies"]
         # Temporary fix for BIDS dependencies path
         # just to release to the competition
@@ -87,7 +110,7 @@ class EEGDashBaseDataset(BaseDataset):
 
         self._raw = None
 
-    def
+    def _get_s3path(self, filepath: str) -> str:
         """Helper to form an AWS S3 URI for the given relative filepath."""
         return f"{self.s3_bucket}/{filepath}"
 
@@ -141,14 +164,22 @@ class EEGDashBaseDataset(BaseDataset):
             if dep.endswith(".set"):
                 dep = dep[:-4] + ".bdf"
 
-            s3path = self.
+            s3path = self._get_s3path(dep)
             if not self.s3_open_neuro:
                 dep = self.bids_dependencies_original[i]
 
-
+            dep_path = Path(dep)
+            if dep_path.parts and dep_path.parts[0] == self.record.get("dataset"):
+                dep_local = Path(self.dataset_folder, *dep_path.parts[1:])
+            else:
+                dep_local = Path(self.dataset_folder) / dep_path
+            filepath = self.cache_dir / dep_local
             if not self.s3_open_neuro:
+                if filepath.suffix == ".set":
+                    filepath = filepath.with_suffix(".bdf")
                 if self.filecache.suffix == ".set":
                     self.filecache = self.filecache.with_suffix(".bdf")
+
             # here, we download the dependency and it is fine
             # in the case of the competition.
             if not filepath.exists():
```
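The same first-path-component swap is applied to each BIDS dependency, so sidecars land in the suffixed cache folder rather than the plain dataset folder. A self-contained sketch with invented paths:

```python
from pathlib import Path

dataset = "ds002718"                  # invented dataset ID
dataset_folder = "ds002718-bdf-mini"  # suffixed cache folder from above
cache_dir = Path("/tmp/eegdash")

dep_path = Path("ds002718/sub-012/eeg/sub-012_task-rest_eeg.json")
if dep_path.parts and dep_path.parts[0] == dataset:
    dep_local = Path(dataset_folder, *dep_path.parts[1:])
else:
    dep_local = Path(dataset_folder) / dep_path
print(cache_dir / dep_local)
# /tmp/eegdash/ds002718-bdf-mini/sub-012/eeg/sub-012_task-rest_eeg.json (on POSIX)
```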
```diff
@@ -174,15 +205,21 @@ class EEGDashBaseDataset(BaseDataset):
                 )
                 filesystem.get(s3path, filepath, callback=callback)
 
-    def
+    def _get_raw_bids_args(self) -> dict[str, Any]:
         """Helper to restrict the metadata record to the fields needed to locate a BIDS
         recording.
         """
         desired_fields = ["subject", "session", "task", "run"]
         return {k: self.record[k] for k in desired_fields if self.record[k]}
 
-    def
+    def _ensure_raw(self) -> None:
         """Download the S3 file and BIDS dependencies if not already cached."""
+        # TO-DO: remove this once is fixed on the our side
+        # for the competition
+        if not self.s3_open_neuro:
+            self.bidspath = self.bidspath.update(extension=".bdf")
+            self.filecache = self.filecache.with_suffix(".bdf")
+
         if not os.path.exists(self.filecache):  # not preload
             if self.bids_dependencies:
                 self._download_dependencies()
```
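`_get_raw_bids_args` is a plain dict comprehension over the metadata record: empty or None entities are dropped, so `BIDSPath` only receives entities that are actually set. For illustration (record values invented):

```python
from typing import Any

record: dict[str, Any] = {
    "subject": "012",
    "session": "",                # falsy -> omitted
    "task": "rest",
    "run": None,                  # falsy -> omitted
    "sampling_frequency": 500.0,  # not a locating field -> never considered
}
desired_fields = ["subject", "session", "task", "run"]
print({k: record[k] for k in desired_fields if record[k]})
# {'subject': '012', 'task': 'rest'}
```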
```diff
@@ -191,14 +228,48 @@ class EEGDashBaseDataset(BaseDataset):
         # capturing any warnings
         # to-do: remove this once is fixed on the mne-bids side.
         with warnings.catch_warnings(record=True) as w:
+            # Ensure all warnings are captured into 'w' and not shown to users
+            warnings.simplefilter("always")
             try:
-                #
-
-
-
-
-
-
+                # mne-bids emits RuntimeWarnings to stderr; silence stderr during read
+                _stderr_buffer = io.StringIO()
+                with redirect_stderr(_stderr_buffer):
+                    self._raw = mne_bids.read_raw_bids(
+                        bids_path=self.bidspath, verbose="ERROR"
+                    )
+                # Parse unmapped participants.tsv fields reported by mne-bids and
+                # inject them into Raw.info and the dataset description generically.
+                extras = self._extract_unmapped_participants_from_warnings(w)
+                if extras:
+                    # 1) Attach to Raw.info under subject_info.participants_extras
+                    try:
+                        subject_info = self._raw.info.get("subject_info") or {}
+                        if not isinstance(subject_info, dict):
+                            subject_info = {}
+                        pe = subject_info.get("participants_extras") or {}
+                        if not isinstance(pe, dict):
+                            pe = {}
+                        # Merge without overwriting
+                        for k, v in extras.items():
+                            pe.setdefault(k, v)
+                        subject_info["participants_extras"] = pe
+                        self._raw.info["subject_info"] = subject_info
+                    except Exception:
+                        # Non-fatal; continue
+                        pass
+
+                    # 2) Also add to this dataset's description, if possible, so
+                    # targets can be selected later without naming specifics.
+                    try:
+                        if isinstance(self.description, dict):
+                            for k, v in extras.items():
+                                self.description.setdefault(k, v)
+                        elif isinstance(self.description, pd.Series):
+                            for k, v in extras.items():
+                                if k not in self.description.index:
+                                    self.description.loc[k] = v
+                    except Exception:
+                        pass
             except Exception as e:
                 logger.error(
                     f"Error while reading BIDS file: {self.bidspath}\n"
```
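The combination of `catch_warnings(record=True)`, `simplefilter("always")`, and `redirect_stderr` is pure standard-library behavior and can be reproduced in isolation; the `noisy_reader` stand-in below plays the role of `mne_bids.read_raw_bids`:

```python
import io
import sys
import warnings
from contextlib import redirect_stderr


def noisy_reader() -> str:
    """Stand-in for a reader that warns and writes progress to stderr."""
    warnings.warn("Unable to map the following column(s) to MNE: ...", RuntimeWarning)
    print("progress spam", file=sys.stderr)
    return "raw"


with warnings.catch_warnings(record=True) as w:
    warnings.simplefilter("always")  # warnings land in `w`, not the console
    buf = io.StringIO()
    with redirect_stderr(buf):       # stderr output is held in the buffer
        result = noisy_reader()

print(f"{len(w)} warning(s) captured; stderr held: {buf.getvalue().strip()!r}")
# 1 warning(s) captured; stderr held: 'progress spam'
```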
```diff
@@ -208,10 +279,59 @@ class EEGDashBaseDataset(BaseDataset):
                 logger.error(f"Exception: {e}")
                 logger.error(traceback.format_exc())
                 raise e
-
-
-
-
+            # Filter noisy mapping notices from mne-bids; surface others
+            for captured_warning in w:
+                try:
+                    msg = str(captured_warning.message)
+                except Exception:
+                    continue
+                # Suppress verbose participants mapping messages
+                if "Unable to map the following column" in msg and "MNE" in msg:
+                    logger.debug(
+                        "Suppressed mne-bids mapping warning while reading BIDS file: %s",
+                        msg,
+                    )
+                    continue
+
+    def _extract_unmapped_participants_from_warnings(
+        self, warnings_list: list[Any]
+    ) -> dict[str, Any]:
+        """Scan captured warnings from mne-bids and extract unmapped participants.tsv
+        entries in a generic way.
+
+        Optionally, the column name can carry a note in parentheses that we ignore
+        for key/value extraction. Returns a mapping of column name -> raw value.
+        """
+        extras: dict[str, Any] = {}
+        header = "Unable to map the following column(s) to MNE:"
+        for wr in warnings_list:
+            try:
+                msg = str(wr.message)
+            except Exception:
+                continue
+            if header not in msg:
+                continue
+            lines = msg.splitlines()
+            # Find the header line, then parse subsequent lines as entries
+            try:
+                idx = next(i for i, ln in enumerate(lines) if header in ln)
+            except StopIteration:
+                idx = -1
+            for line in lines[idx + 1 :]:
+                line = line.strip()
+                if not line:
+                    continue
+                # Pattern: <col>(optional note): <value>
+                # Examples: "gender: F", "Ethnicity: Indian", "foo (ignored): bar"
+                m = re.match(r"^([^:]+?)(?:\s*\([^)]*\))?\s*:\s*(.*)$", line)
+                if not m:
+                    continue
+                col = m.group(1).strip()
+                val = m.group(2).strip()
+                # Keep original column names as provided to stay agnostic
+                if col and col not in extras:
+                    extras[col] = val
+        return extras
 
     # === BaseDataset and PyTorch Dataset interface ===
 
```
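The extraction regex can be exercised on a fabricated message of the shape mne-bids emits; the header line is skipped, and a parenthesized note after a column name is ignored:

```python
import re

# Sample warning text, invented to match the documented shape:
msg = """Unable to map the following column(s) to MNE:
gender: F
Ethnicity: Indian
handedness (ignored): right"""

extras: dict[str, str] = {}
for line in msg.splitlines()[1:]:  # skip the header line
    m = re.match(r"^([^:]+?)(?:\s*\([^)]*\))?\s*:\s*(.*)$", line.strip())
    if m:
        extras.setdefault(m.group(1).strip(), m.group(2).strip())
print(extras)
# {'gender': 'F', 'Ethnicity': 'Indian', 'handedness': 'right'}
```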
```diff
@@ -230,11 +350,16 @@ class EEGDashBaseDataset(BaseDataset):
     def __len__(self) -> int:
         """Return the number of samples in the dataset."""
         if self._raw is None:
-
-
-
-
-
+            if (
+                self.record["ntimes"] is None
+                or self.record["sampling_frequency"] is None
+            ):
+                self._ensure_raw()
+            else:
+                # FIXME: this is a bit strange and should definitely not change as a side effect
+                # of accessing the data (which it will, since ntimes is the actual length but rounded down)
+                return int(self.record["ntimes"] * self.record["sampling_frequency"])
+        return len(self._raw)
 
     @property
     def raw(self):
```
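Restated as a free function, the `__len__` fallback avoids touching S3 when the record already carries enough metadata. A sketch with invented record values; the `RuntimeError` marks where the real method calls `self._ensure_raw()`:

```python
from typing import Any, Optional


def dataset_len(record: dict[str, Any], raw: Optional[list]) -> int:
    if raw is None:
        if record["ntimes"] is None or record["sampling_frequency"] is None:
            raise RuntimeError("metadata missing: would download via _ensure_raw()")
        # ntimes appears to hold the recording duration in seconds
        return int(record["ntimes"] * record["sampling_frequency"])
    return len(raw)


print(dataset_len({"ntimes": 338.2, "sampling_frequency": 500.0}, None))  # 169100
```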
```diff
@@ -242,7 +367,7 @@ class EEGDashBaseDataset(BaseDataset):
         retrieval if not yet done so.
         """
         if self._raw is None:
-            self.
+            self._ensure_raw()
         return self._raw
 
     @raw.setter
@@ -284,7 +409,7 @@ class EEGDashBaseRaw(BaseRaw):
         metadata: dict[str, Any],
         preload: bool = False,
         *,
-        cache_dir: str =
+        cache_dir: str | None = None,
         bids_dependencies: list[str] = [],
         verbose: Any = None,
     ):
@@ -300,8 +425,9 @@ class EEGDashBaseRaw(BaseRaw):
                 chtype = "eog"
             ch_types.append(chtype)
         info = mne.create_info(ch_names=ch_names, sfreq=sfreq, ch_types=ch_types)
-
-        self.
+
+        self.s3file = self._get_s3path(input_fname)
+        self.cache_dir = Path(cache_dir) if cache_dir else get_default_cache_dir()
         self.filecache = self.cache_dir / input_fname
         self.bids_dependencies = bids_dependencies
 
@@ -317,7 +443,7 @@ class EEGDashBaseRaw(BaseRaw):
             verbose=verbose,
         )
 
-    def
+    def _get_s3path(self, filepath):
         return f"{self._AWS_BUCKET}/{filepath}"
 
     def _download_s3(self) -> None:
@@ -333,7 +459,7 @@ class EEGDashBaseRaw(BaseRaw):
                 anon=True, client_kwargs={"region_name": "us-east-2"}
             )
             for dep in self.bids_dependencies:
-                s3path = self.
+                s3path = self._get_s3path(dep)
                 filepath = self.cache_dir / dep
                 if not filepath.exists():
                     filepath.parent.mkdir(parents=True, exist_ok=True)
@@ -394,11 +520,17 @@ class EEGBIDSDataset:
             raise ValueError("data_dir must be specified and must exist")
         self.bidsdir = Path(data_dir)
         self.dataset = dataset
-
+        # Accept exact dataset folder or a variant with informative suffixes
+        # (e.g., dsXXXXX-bdf, dsXXXXX-bdf-mini) to avoid collisions.
+        dir_name = self.bidsdir.name
+        if not (dir_name == self.dataset or dir_name.startswith(self.dataset + "-")):
+            raise AssertionError(
+                f"BIDS directory '{dir_name}' does not correspond to dataset '{self.dataset}'"
+            )
         self.layout = BIDSLayout(data_dir)
 
         # get all recording files in the bids directory
-        self.files = self.
+        self.files = self._get_recordings(self.layout)
         assert len(self.files) > 0, ValueError(
             "Unable to construct EEG dataset. No EEG recordings found."
         )
```
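The relaxed check accepts the plain dataset folder as well as the suffixed cache variants, while still rejecting unrelated folders; the names below are illustrative:

```python
dataset = "ds002718"  # invented ID
for dir_name in ("ds002718", "ds002718-bdf", "ds002718-bdf-mini", "ds002719"):
    ok = dir_name == dataset or dir_name.startswith(dataset + "-")
    print(f"{dir_name}: {'accepted' if ok else 'rejected'}")
# only ds002719 is rejected
```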
```diff
@@ -408,7 +540,7 @@ class EEGBIDSDataset:
         """Check if the dataset is EEG."""
         return self.get_bids_file_attribute("modality", self.files[0]).lower() == "eeg"
 
-    def
+    def _get_recordings(self, layout: BIDSLayout) -> list[str]:
         """Get a list of all EEG recording files in the BIDS layout."""
         files = []
         for ext, exts in self.RAW_EXTENSIONS.items():
@@ -417,12 +549,12 @@ class EEGBIDSDataset:
                 break
         return files
 
-    def
+    def _get_relative_bidspath(self, filename: str) -> str:
         """Make the given file path relative to the BIDS directory."""
         bids_parent_dir = self.bidsdir.parent.absolute()
         return str(Path(filename).relative_to(bids_parent_dir))
 
-    def
+    def _get_property_from_filename(self, property: str, filename: str) -> str:
         """Parse a property out of a BIDS-compliant filename. Returns an empty string
         if not found.
         """
@@ -434,7 +566,7 @@ class EEGBIDSDataset:
         lookup = re.search(rf"{property}-(.*?)[_\/]", filename)
         return lookup.group(1) if lookup else ""
 
-    def
+    def _merge_json_inheritance(self, json_files: list[str | Path]) -> dict:
         """Internal helper to merge list of json files found by get_bids_file_inheritance,
         expecting the order (from left to right) is from lowest
         level to highest level, and return a merged dictionary
@@ -445,7 +577,7 @@ class EEGBIDSDataset:
             json_dict.update(json.load(open(f)))  # FIXME: should close file
         return json_dict
 
-    def
+    def _get_bids_file_inheritance(
         self, path: str | Path, basename: str, extension: str
     ) -> list[Path]:
         """Get all file paths that apply to the basename file in the specified directory
```
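`_merge_json_inheritance` reduces the file list with `dict.update`, so keys from files later in the list overwrite earlier ones. A runnable sketch of that merge order using two temporary sidecar files (contents invented):

```python
import json
from pathlib import Path
from tempfile import TemporaryDirectory

with TemporaryDirectory() as tmp:
    low = Path(tmp, "sub-01_eeg.json")   # lower (more specific) level
    high = Path(tmp, "eeg.json")         # higher (more general) level
    low.write_text(json.dumps({"SamplingFrequency": 500, "EEGReference": "Cz"}))
    high.write_text(json.dumps({"SamplingFrequency": 512, "PowerLineFrequency": 50}))

    merged: dict = {}
    for f in [low, high]:  # lowest level first, as the helper expects
        merged.update(json.loads(f.read_text()))
    print(merged)
    # {'SamplingFrequency': 512, 'EEGReference': 'Cz', 'PowerLineFrequency': 50}
```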
```diff
@@ -492,7 +624,7 @@ class EEGBIDSDataset:
         else:
             # call get_bids_file_inheritance recursively with parent directory
             bids_files.extend(
-                self.
+                self._get_bids_file_inheritance(path.parent, basename, extension)
             )
         return bids_files
 
@@ -523,12 +655,12 @@ class EEGBIDSDataset:
         path, filename = os.path.split(filepath)
         basename = filename[: filename.rfind("_")]
         # metadata files
-        meta_files = self.
+        meta_files = self._get_bids_file_inheritance(
             path, basename, metadata_file_extension
         )
         return meta_files
 
-    def
+    def _scan_directory(self, directory: str, extension: str) -> list[Path]:
         """Return a list of file paths that end with the given extension in the specified
         directory. Ignores certain special directories like .git, .datalad, derivatives,
         and code.
@@ -545,7 +677,7 @@ class EEGBIDSDataset:
                 result_files.append(entry.path)  # Add directory to scan later
         return result_files
 
-    def
+    def _get_files_with_extension_parallel(
         self, directory: str, extension: str = ".set", max_workers: int = -1
     ) -> list[Path]:
         """Efficiently scan a directory and its subdirectories for files that end with
@@ -577,7 +709,7 @@ class EEGBIDSDataset:
         )
         # Run the scan_directory function in parallel across directories
         results = Parallel(n_jobs=max_workers, prefer="threads", verbose=1)(
-            delayed(self.
+            delayed(self._scan_directory)(d, extension) for d in dirs_to_scan
         )
 
         # Reset the directories to scan and process the results
```
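The joblib pattern above fans a scan function out over directories on a thread pool and returns one result list per directory; a minimal self-contained version (the directories and extension are placeholders):

```python
from pathlib import Path

from joblib import Parallel, delayed


def scan_directory(directory: str, extension: str) -> list[str]:
    """Toy scanner: list matching files directly inside one directory."""
    return [str(p) for p in Path(directory).iterdir() if p.name.endswith(extension)]


dirs_to_scan = [".", ".."]  # stand-ins for subject/session folders
results = Parallel(n_jobs=-1, prefer="threads")(
    delayed(scan_directory)(d, ".json") for d in dirs_to_scan
)
files = [f for per_dir in results for f in per_dir]  # flatten per-directory lists
print(files)
```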
```diff
@@ -682,7 +814,7 @@ class EEGBIDSDataset:
     def num_times(self, data_filepath: str) -> int:
         """Get the approximate number of time points in the EEG recording based on the BIDS metadata."""
         eeg_jsons = self.get_bids_metadata_files(data_filepath, "eeg.json")
-        eeg_json_dict = self.
+        eeg_json_dict = self._merge_json_inheritance(eeg_jsons)
         return int(
             eeg_json_dict["SamplingFrequency"] * eeg_json_dict["RecordingDuration"]
         )
```
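The estimate is simply sampling rate times duration from the merged sidecar; with hypothetical values:

```python
eeg_json_dict = {"SamplingFrequency": 500, "RecordingDuration": 338.2}
n_times = int(eeg_json_dict["SamplingFrequency"] * eeg_json_dict["RecordingDuration"])
print(n_times)  # 169100
```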
```diff
@@ -705,7 +837,7 @@ class EEGBIDSDataset:
     def eeg_json(self, data_filepath: str) -> dict[str, Any]:
         """Get BIDS eeg.json metadata for the given data file path."""
         eeg_jsons = self.get_bids_metadata_files(data_filepath, "eeg.json")
-        eeg_json_dict = self.
+        eeg_json_dict = self._merge_json_inheritance(eeg_jsons)
         return eeg_json_dict
 
     def channel_tsv(self, data_filepath: str) -> dict[str, Any]:
```
eegdash/dataset/dataset.py
RENAMED
```diff
@@ -3,8 +3,9 @@ from pathlib import Path
 
 from mne.utils import warn
 
-from
-from
+from ..api import EEGDashDataset
+from ..bids_eeg_metadata import build_query_from_kwargs
+from ..const import RELEASE_TO_OPENNEURO_DATASET_MAP, SUBJECT_MINI_RELEASE_MAP
 from .registry import register_openneuro_datasets
 
 logger = logging.getLogger("eegdash")
```
```diff
@@ -68,15 +69,56 @@ class EEGChallengeDataset(EEGDashDataset):
         )
 
         if self.mini:
-            #
-            #
-
-
-
-
-
+            # When using the mini release, restrict subjects to the predefined subset.
+            # If the user specifies subject(s), ensure they all belong to the mini subset;
+            # otherwise, default to the full mini subject list for this release.
+
+            allowed_subjects = set(SUBJECT_MINI_RELEASE_MAP[release])
+
+            # Normalize potential 'subjects' -> 'subject' for convenience
+            if "subjects" in kwargs and "subject" not in kwargs:
+                kwargs["subject"] = kwargs.pop("subjects")
+
+            # Collect user-requested subjects from kwargs/query. We canonicalize
+            # kwargs via build_query_from_kwargs to leverage existing validation,
+            # and support Mongo-style {"$in": [...]} shapes from a raw query.
+            requested_subjects: list[str] = []
+
+            # From kwargs
+            if "subject" in kwargs and kwargs["subject"] is not None:
+                # Use the shared query builder to normalize scalars/lists
+                built = build_query_from_kwargs(subject=kwargs["subject"])
+                s_val = built.get("subject")
+                if isinstance(s_val, dict) and "$in" in s_val:
+                    requested_subjects.extend(list(s_val["$in"]))
+                elif s_val is not None:
+                    requested_subjects.append(s_val)  # type: ignore[arg-type]
+
+            # From query (top-level only)
+            if query and isinstance(query, dict) and "subject" in query:
+                qval = query["subject"]
+                if isinstance(qval, dict) and "$in" in qval:
+                    requested_subjects.extend(list(qval["$in"]))
+                elif isinstance(qval, (list, tuple, set)):
+                    requested_subjects.extend(list(qval))
+                elif qval is not None:
+                    requested_subjects.append(qval)
+
+            # Validate if any subjects were explicitly requested
+            if requested_subjects:
+                invalid = sorted(
+                    {s for s in requested_subjects if s not in allowed_subjects}
                 )
-
+                if invalid:
+                    raise ValueError(
+                        "Some requested subject(s) are not part of the mini release for "
+                        f"{release}: {invalid}. Allowed subjects: {sorted(allowed_subjects)}"
+                    )
+                # Do not override user selection; keep their (validated) subjects as-is.
+            else:
+                # No subject specified by the user: default to the full mini subset
+                kwargs["subject"] = sorted(allowed_subjects)
+
             s3_bucket = f"{s3_bucket}/{release}_mini_L100_bdf"
         else:
             s3_bucket = f"{s3_bucket}/{release}_L100_bdf"
```
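Taken together, the mini-release handling means an out-of-subset subject now raises immediately instead of silently matching nothing. A usage sketch; the import path and constructor arguments are inferred from this diff and may not match the released signature exactly, and the subject ID is a placeholder:

```python
from eegdash.dataset import EEGChallengeDataset  # inferred import path

try:
    ds = EEGChallengeDataset(release="R1", mini=True, subject="NDARXX000XXX")
except ValueError as err:
    print(err)  # names the invalid subject(s) and lists the allowed mini subset

# Omitting `subject` defaults to every subject in the mini subset.
ds = EEGChallengeDataset(release="R1", mini=True)
```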
```diff
@@ -104,6 +146,7 @@ class EEGChallengeDataset(EEGDashDataset):
             query=query,
             cache_dir=cache_dir,
             s3_bucket=s3_bucket,
+            _suppress_comp_warning=True,
             **kwargs,
         )
 
```