eegdash 0.3.7.dev183881899__py3-none-any.whl → 0.3.9.dev114__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- eegdash/__init__.py +5 -5
- eegdash/api.py +533 -467
- eegdash/bids_eeg_metadata.py +254 -0
- eegdash/const.py +48 -0
- eegdash/data_utils.py +177 -45
- eegdash/dataset/__init__.py +4 -0
- eegdash/dataset/dataset.py +161 -0
- eegdash/dataset/dataset_summary.csv +256 -0
- eegdash/{registry.py → dataset/registry.py} +9 -20
- eegdash/hbn/__init__.py +17 -0
- eegdash/hbn/windows.py +305 -0
- eegdash/paths.py +28 -0
- eegdash/utils.py +1 -1
- {eegdash-0.3.7.dev183881899.dist-info → eegdash-0.3.9.dev114.dist-info}/METADATA +12 -5
- eegdash-0.3.9.dev114.dist-info/RECORD +35 -0
- eegdash/data_config.py +0 -34
- eegdash/dataset.py +0 -118
- eegdash/dataset_summary.csv +0 -256
- eegdash-0.3.7.dev183881899.dist-info/RECORD +0 -31
- /eegdash/{preprocessing.py → hbn/preprocessing.py} +0 -0
- {eegdash-0.3.7.dev183881899.dist-info → eegdash-0.3.9.dev114.dist-info}/WHEEL +0 -0
- {eegdash-0.3.7.dev183881899.dist-info → eegdash-0.3.9.dev114.dist-info}/licenses/LICENSE +0 -0
- {eegdash-0.3.7.dev183881899.dist-info → eegdash-0.3.9.dev114.dist-info}/top_level.txt +0 -0
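Beyond api.py, the file list implies two module moves that affect import paths: registry.py now lives under eegdash/dataset/ and the HBN preprocessing module under eegdash/hbn/. A minimal compatibility sketch follows; the new submodule names are taken from the file list above, while whether 0.3.9 keeps aliases at the old locations is an assumption not shown in this diff.

# Hedged compatibility shim for the moves listed above; the fallback branch
# assumes the pre-0.3.9 layout, which this diff shows being removed.
try:
    from eegdash.dataset import registry   # >= 0.3.9: eegdash/dataset/registry.py
    from eegdash.hbn import preprocessing  # >= 0.3.9: eegdash/hbn/preprocessing.py
except ImportError:
    from eegdash import registry           # 0.3.7: eegdash/registry.py
    from eegdash import preprocessing      # 0.3.7: eegdash/preprocessing.py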
eegdash/api.py CHANGED
@@ -3,69 +3,68 @@ import os
 import tempfile
 from pathlib import Path
 from typing import Any, Mapping
+from urllib.parse import urlsplit

 import mne
 import numpy as np
-import platformdirs
 import xarray as xr
+from docstring_inheritance import NumpyDocstringInheritanceInitMeta
 from dotenv import load_dotenv
 from joblib import Parallel, delayed
 from mne.utils import warn
-from mne_bids import get_bids_path_from_fname, read_raw_bids
+from mne_bids import find_matching_paths, get_bids_path_from_fname, read_raw_bids
 from pymongo import InsertOne, UpdateOne
 from s3fs import S3FileSystem

 from braindecode.datasets import BaseConcatDataset

-from .
-
-
+from .bids_eeg_metadata import (
+    build_query_from_kwargs,
+    load_eeg_attrs_from_bids_file,
+    merge_participants_fields,
+    normalize_key,
+)
+from .const import (
+    ALLOWED_QUERY_FIELDS,
+    RELEASE_TO_OPENNEURO_DATASET_MAP,
+)
+from .const import config as data_config
+from .data_utils import (
+    EEGBIDSDataset,
+    EEGDashBaseDataset,
+)
 from .mongodb import MongoConnectionManager
+from .paths import get_default_cache_dir

 logger = logging.getLogger("eegdash")


 class EEGDash:
-    """
+    """High-level interface to the EEGDash metadata database.

-
-    EEGDash database (or
-
-
-    While this class provides basic support for loading EEG data, please see
-    the EEGDashDataset class for a more complete way to retrieve and work with full
-    datasets.
+    Provides methods to query, insert, and update metadata records stored in the
+    EEGDash MongoDB database (public or private). Also includes utilities to load
+    EEG data from S3 for matched records.

+    For working with collections of
+    recordings as PyTorch datasets, prefer :class:`EEGDashDataset`.
     """

-    _ALLOWED_QUERY_FIELDS = {
-        "data_name",
-        "dataset",
-        "subject",
-        "task",
-        "session",
-        "run",
-        "modality",
-        "sampling_frequency",
-        "nchans",
-        "ntimes",
-    }
-
     def __init__(self, *, is_public: bool = True, is_staging: bool = False) -> None:
-        """Create new
+        """Create a new EEGDash client.

         Parameters
         ----------
-        is_public: bool
-
-            private database instance
-            (or
-        is_staging: bool
-            If True
-            production database (
-
-
-
+        is_public : bool, default True
+            Connect to the public MongoDB database. If ``False``, connect to a
+            private database instance using the ``DB_CONNECTION_STRING`` environment
+            variable (or value from a ``.env`` file).
+        is_staging : bool, default False
+            If ``True``, use the staging database (``eegdashstaging``); otherwise
+            use the production database (``eegdash``).
+
+        Examples
+        --------
         >>> eegdash = EEGDash()

         """
@@ -93,36 +92,35 @@ class EEGDash:
     ) -> list[Mapping[str, Any]]:
         """Find records in the MongoDB collection.

-
-
-
-
-
-
-
-        >>> eegdash.find({}) # fetches all records (use with care)
-        4. By combining a raw query with kwargs (merged via logical AND):
-        >>> eegdash.find({"dataset": "ds002718"}, subject=["012", "013"]) # yields {"$and":[{"dataset":"ds002718"}, {"subject":{"$in":["012","013"]}}]}
+        Examples
+        --------
+        >>> eegdash.find({"dataset": "ds002718", "subject": {"$in": ["012", "013"]}})  # pre-built query
+        >>> eegdash.find(dataset="ds002718", subject="012")  # keyword filters
+        >>> eegdash.find(dataset="ds002718", subject=["012", "013"])  # sequence -> $in
+        >>> eegdash.find({})  # fetch all (use with care)
+        >>> eegdash.find({"dataset": "ds002718"}, subject=["012", "013"])  # combine query + kwargs (AND)

         Parameters
         ----------
-        query: dict, optional
-
-
-
-
+        query : dict, optional
+            Complete MongoDB query dictionary. This is a positional-only
+            argument.
+        **kwargs
+            User-friendly field filters that are converted to a MongoDB query.
+            Values can be scalars (e.g., ``"sub-01"``) or sequences (translated
+            to ``$in`` queries).

         Returns
         -------
-        list
-
+        list of dict
+            DB records that match the query.

         """
         final_query: dict[str, Any] | None = None

         # Accept explicit empty dict {} to mean "match all"
         raw_query = query if isinstance(query, dict) else None
-        kwargs_query =
+        kwargs_query = build_query_from_kwargs(**kwargs) if kwargs else None

         # Determine presence, treating {} as a valid raw query
         has_raw = isinstance(raw_query, dict)
@@ -239,59 +237,12 @@ class EEGDash:
         return record

     def _build_query_from_kwargs(self, **kwargs) -> dict[str, Any]:
-        """
+        """Internal helper to build a validated MongoDB query from keyword args.

-
-
-        - For list/tuple/set values: strip strings, drop None/empties, deduplicate, and use `$in`
-        - Preserve scalars as exact matches
+        This delegates to the module-level builder used across the package and
+        is exposed here for testing and convenience.
         """
-
-        unknown_fields = set(kwargs.keys()) - self._ALLOWED_QUERY_FIELDS
-        if unknown_fields:
-            raise ValueError(
-                f"Unsupported query field(s): {', '.join(sorted(unknown_fields))}. "
-                f"Allowed fields are: {', '.join(sorted(self._ALLOWED_QUERY_FIELDS))}"
-            )
-
-        # 2. Construct the query dictionary
-        query = {}
-        for key, value in kwargs.items():
-            # None is not a valid constraint
-            if value is None:
-                raise ValueError(
-                    f"Received None for query parameter '{key}'. Provide a concrete value."
-                )
-
-            # Handle list-like values as multi-constraints
-            if isinstance(value, (list, tuple, set)):
-                cleaned: list[Any] = []
-                for item in value:
-                    if item is None:
-                        continue
-                    if isinstance(item, str):
-                        item = item.strip()
-                        if not item:
-                            continue
-                    cleaned.append(item)
-                # Deduplicate while preserving order
-                cleaned = list(dict.fromkeys(cleaned))
-                if not cleaned:
-                    raise ValueError(
-                        f"Received an empty list for query parameter '{key}'. This is not supported."
-                    )
-                query[key] = {"$in": cleaned}
-            else:
-                # Scalars: trim strings and validate
-                if isinstance(value, str):
-                    value = value.strip()
-                if not value:
-                    raise ValueError(
-                        f"Received an empty string for query parameter '{key}'."
-                    )
-                query[key] = value
-
-        return query
+        return build_query_from_kwargs(**kwargs)

     # --- Query merging and conflict detection helpers ---
     def _extract_simple_constraint(self, query: dict[str, Any], key: str):
@@ -324,8 +275,8 @@ class EEGDash:
             return

         # Only consider fields we generally allow; skip meta operators like $and
-        raw_keys = set(raw_query.keys()) &
-        kw_keys = set(kwargs_query.keys()) &
+        raw_keys = set(raw_query.keys()) & ALLOWED_QUERY_FIELDS
+        kw_keys = set(kwargs_query.keys()) & ALLOWED_QUERY_FIELDS
         dup_keys = raw_keys & kw_keys
         for key in dup_keys:
             rc = self._extract_simple_constraint(raw_query, key)
@@ -360,44 +311,95 @@ class EEGDash:
         )

     def load_eeg_data_from_s3(self, s3path: str) -> xr.DataArray:
-        """Load
+        """Load EEG data from an S3 URI into an ``xarray.DataArray``.
+
+        Preserves the original filename, downloads sidecar files when applicable
+        (e.g., ``.fdt`` for EEGLAB, ``.vmrk``/``.eeg`` for BrainVision), and uses
+        MNE's direct readers.

         Parameters
         ----------
         s3path : str
-            An S3 URI (should start with "s3://")
+            An S3 URI (should start with "s3://").

         Returns
         -------
         xr.DataArray
-
+            EEG data with dimensions ``("channel", "time")``.

-
-
-
-
-        >>> mydata = eegdash.load_eeg_data_from_s3(mypath)
+        Raises
+        ------
+        ValueError
+            If the file extension is unsupported.

         """
-
-
-
-
-
-
-
+        # choose a temp dir so sidecars can be colocated
+        with tempfile.TemporaryDirectory() as tmpdir:
+            # Derive local filenames from the S3 key to keep base name consistent
+            s3_key = urlsplit(s3path).path  # e.g., "/dsXXXX/sub-.../..._eeg.set"
+            basename = Path(s3_key).name
+            ext = Path(basename).suffix.lower()
+            local_main = Path(tmpdir) / basename
+
+            # Download main file
+            with (
+                self.filesystem.open(s3path, mode="rb") as fsrc,
+                open(local_main, "wb") as fdst,
+            ):
+                fdst.write(fsrc.read())
+
+            # Determine and fetch any required sidecars
+            sidecars: list[str] = []
+            if ext == ".set":  # EEGLAB
+                sidecars = [".fdt"]
+            elif ext == ".vhdr":  # BrainVision
+                sidecars = [".vmrk", ".eeg", ".dat", ".raw"]
+
+            for sc_ext in sidecars:
+                sc_key = s3_key[: -len(ext)] + sc_ext
+                sc_uri = f"s3://{urlsplit(s3path).netloc}{sc_key}"
+                try:
+                    # If sidecar exists, download next to the main file
+                    info = self.filesystem.info(sc_uri)
+                    if info:
+                        sc_local = Path(tmpdir) / Path(sc_key).name
+                        with (
+                            self.filesystem.open(sc_uri, mode="rb") as fsrc,
+                            open(sc_local, "wb") as fdst,
+                        ):
+                            fdst.write(fsrc.read())
+                except Exception:
+                    # Sidecar not present; skip silently
+                    pass
+
+            # Read using appropriate MNE reader
+            raw = mne.io.read_raw(str(local_main), preload=True, verbose=False)
+
+            data = raw.get_data()
+            fs = raw.info["sfreq"]
+            max_time = data.shape[1] / fs
+            time_steps = np.linspace(0, max_time, data.shape[1]).squeeze()
+            channel_names = raw.ch_names
+
+            return xr.DataArray(
+                data=data,
+                dims=["channel", "time"],
+                coords={"time": time_steps, "channel": channel_names},
+            )

     def load_eeg_data_from_bids_file(self, bids_file: str) -> xr.DataArray:
-        """Load EEG data from a local file
+        """Load EEG data from a local BIDS-formatted file.

         Parameters
         ----------
         bids_file : str
-            Path to
+            Path to a BIDS-compliant EEG file (e.g., ``*_eeg.edf``, ``*_eeg.bdf``,
+            ``*_eeg.vhdr``, ``*_eeg.set``).

-
-
-
+        Returns
+        -------
+        xr.DataArray
+            EEG data with dimensions ``("channel", "time")``.

         """
         bids_path = get_bids_path_from_fname(bids_file, verbose=False)
@@ -417,140 +419,25 @@ class EEGDash:
         )
         return eeg_xarray

-    def get_raw_extensions(
-        self, bids_file: str, bids_dataset: EEGBIDSDataset
-    ) -> list[str]:
-        """Helper to find paths to additional "sidecar" files that may be associated
-        with a given main data file in a BIDS dataset; paths are returned as relative to
-        the parent dataset path.
-
-        For example, if the input file is a .set file, this will return the relative path
-        to a corresponding .fdt file (if any).
-        """
-        bids_file = Path(bids_file)
-        extensions = {
-            ".set": [".set", ".fdt"],  # eeglab
-            ".edf": [".edf"],  # european
-            ".vhdr": [".eeg", ".vhdr", ".vmrk", ".dat", ".raw"],  # brainvision
-            ".bdf": [".bdf"],  # biosemi
-        }
-        return [
-            str(bids_dataset.get_relative_bidspath(bids_file.with_suffix(suffix)))
-            for suffix in extensions[bids_file.suffix]
-            if bids_file.with_suffix(suffix).exists()
-        ]
-
-    def load_eeg_attrs_from_bids_file(
-        self, bids_dataset: EEGBIDSDataset, bids_file: str
-    ) -> dict[str, Any]:
-        """Build the metadata record for a given BIDS file (single recording) in a BIDS dataset.
-
-        Attributes are at least the ones defined in data_config attributes (set to None if missing),
-        but are typically a superset, and include, among others, the paths to relevant
-        meta-data files needed to load and interpret the file in question.
-
-        Parameters
-        ----------
-        bids_dataset : EEGBIDSDataset
-            The BIDS dataset object containing the file.
-        bids_file : str
-            The path to the BIDS file within the dataset.
-
-        Returns
-        -------
-        dict:
-            A dictionary representing the metadata record for the given file. This is the
-            same format as the records stored in the database.
-
-        """
-        if bids_file not in bids_dataset.files:
-            raise ValueError(f"{bids_file} not in {bids_dataset.dataset}")
-
-        # Initialize attrs with None values for all expected fields
-        attrs = {field: None for field in self.config["attributes"].keys()}
-
-        file = Path(bids_file).name
-        dsnumber = bids_dataset.dataset
-        # extract openneuro path by finding the first occurrence of the dataset name in the filename and remove the path before that
-        openneuro_path = dsnumber + bids_file.split(dsnumber)[1]
-
-        # Update with actual values where available
-        try:
-            participants_tsv = bids_dataset.subject_participant_tsv(bids_file)
-        except Exception as e:
-            logger.error("Error getting participants_tsv: %s", str(e))
-            participants_tsv = None
-
-        try:
-            eeg_json = bids_dataset.eeg_json(bids_file)
-        except Exception as e:
-            logger.error("Error getting eeg_json: %s", str(e))
-            eeg_json = None
-
-        bids_dependencies_files = self.config["bids_dependencies_files"]
-        bidsdependencies = []
-        for extension in bids_dependencies_files:
-            try:
-                dep_path = bids_dataset.get_bids_metadata_files(bids_file, extension)
-                dep_path = [
-                    str(bids_dataset.get_relative_bidspath(dep)) for dep in dep_path
-                ]
-                bidsdependencies.extend(dep_path)
-            except Exception:
-                pass
-
-        bidsdependencies.extend(self.get_raw_extensions(bids_file, bids_dataset))
-
-        # Define field extraction functions with error handling
-        field_extractors = {
-            "data_name": lambda: f"{bids_dataset.dataset}_{file}",
-            "dataset": lambda: bids_dataset.dataset,
-            "bidspath": lambda: openneuro_path,
-            "subject": lambda: bids_dataset.get_bids_file_attribute(
-                "subject", bids_file
-            ),
-            "task": lambda: bids_dataset.get_bids_file_attribute("task", bids_file),
-            "session": lambda: bids_dataset.get_bids_file_attribute(
-                "session", bids_file
-            ),
-            "run": lambda: bids_dataset.get_bids_file_attribute("run", bids_file),
-            "modality": lambda: bids_dataset.get_bids_file_attribute(
-                "modality", bids_file
-            ),
-            "sampling_frequency": lambda: bids_dataset.get_bids_file_attribute(
-                "sfreq", bids_file
-            ),
-            "nchans": lambda: bids_dataset.get_bids_file_attribute("nchans", bids_file),
-            "ntimes": lambda: bids_dataset.get_bids_file_attribute("ntimes", bids_file),
-            "participant_tsv": lambda: participants_tsv,
-            "eeg_json": lambda: eeg_json,
-            "bidsdependencies": lambda: bidsdependencies,
-        }
-
-        # Dynamically populate attrs with error handling
-        for field, extractor in field_extractors.items():
-            try:
-                attrs[field] = extractor()
-            except Exception as e:
-                logger.error("Error extracting %s : %s", field, str(e))
-                attrs[field] = None
-
-        return attrs
-
     def add_bids_dataset(
         self, dataset: str, data_dir: str, overwrite: bool = True
     ) -> None:
-        """
-        under the given dataset name.
+        """Scan a local BIDS dataset and upsert records into MongoDB.

         Parameters
         ----------
-        dataset : str
-
+        dataset : str
+            Dataset identifier (e.g., ``"ds002718"``).
         data_dir : str
-
-        overwrite : bool
-
+            Path to the local BIDS dataset directory.
+        overwrite : bool, default True
+            If ``True``, update existing records when encountered; otherwise,
+            skip records that already exist.
+
+        Raises
+        ------
+        ValueError
+            If called on a public client ``(is_public=True)``.

         """
         if self.is_public:
@@ -565,7 +452,7 @@ class EEGDash:
                 dataset=dataset,
             )
         except Exception as e:
-            logger.error("Error creating bids dataset %s:
+            logger.error("Error creating bids dataset %s: %s", dataset, str(e))
             raise e
         requests = []
         for bids_file in bids_dataset.get_files():
@@ -574,15 +461,13 @@ class EEGDash:

                 if self.exist({"data_name": data_id}):
                     if overwrite:
-                        eeg_attrs =
+                        eeg_attrs = load_eeg_attrs_from_bids_file(
                             bids_dataset, bids_file
                         )
-                        requests.append(self.
+                        requests.append(self._update_request(eeg_attrs))
                 else:
-                    eeg_attrs =
-
-                    )
-                    requests.append(self.add_request(eeg_attrs))
+                    eeg_attrs = load_eeg_attrs_from_bids_file(bids_dataset, bids_file)
+                    requests.append(self._add_request(eeg_attrs))
             except Exception as e:
                 logger.error("Error adding record %s", bids_file)
                 logger.error(str(e))
@@ -598,22 +483,22 @@ class EEGDash:
         logger.info("Errors: %s ", result.bulk_api_result.get("writeErrors", []))

     def get(self, query: dict[str, Any]) -> list[xr.DataArray]:
-        """
-        the `find()` method for details on the query format.
+        """Download and return EEG data arrays for records matching a query.

         Parameters
         ----------
         query : dict
-
-            document that is used to match records in the MongoDB collection.
+            MongoDB query used to select records.

         Returns
         -------
-
+        list of xr.DataArray
+            EEG data for each matching record, with dimensions ``("channel", "time")``.

         Notes
         -----
-        Retrieval
+        Retrieval runs in parallel. Downloaded files are read and discarded
+        (no on-disk caching here).

         """
         sessions = self.find(query)
@@ -623,12 +508,40 @@ class EEGDash:
         results = Parallel(
             n_jobs=-1 if len(sessions) > 1 else 1, prefer="threads", verbose=1
         )(
-            delayed(self.load_eeg_data_from_s3)(self.
+            delayed(self.load_eeg_data_from_s3)(self._get_s3path(session))
             for session in sessions
         )
         return results

-    def
+    def _get_s3path(self, record: Mapping[str, Any] | str) -> str:
+        """Build an S3 URI from a DB record or a relative path.
+
+        Parameters
+        ----------
+        record : dict or str
+            Either a DB record containing a ``'bidspath'`` key, or a relative
+            path string under the OpenNeuro bucket.
+
+        Returns
+        -------
+        str
+            Fully qualified S3 URI.
+
+        Raises
+        ------
+        ValueError
+            If a mapping is provided but ``'bidspath'`` is missing.
+
+        """
+        if isinstance(record, str):
+            rel = record
+        else:
+            rel = record.get("bidspath")
+            if not rel:
+                raise ValueError("Record missing 'bidspath' for S3 path resolution")
+        return f"s3://openneuro.org/{rel}"
+
+    def _add_request(self, record: dict):
         """Internal helper method to create a MongoDB insertion request for a record."""
         return InsertOne(record)

@@ -642,12 +555,19 @@ class EEGDash:
         except:
             logger.error("Error adding record: %s ", record["data_name"])

-    def
+    def _update_request(self, record: dict):
         """Internal helper method to create a MongoDB update request for a record."""
         return UpdateOne({"data_name": record["data_name"]}, {"$set": record})

     def update(self, record: dict):
-        """Update a single record in the MongoDB collection.
+        """Update a single record in the MongoDB collection.
+
+        Parameters
+        ----------
+        record : dict
+            Record content to set at the matching ``data_name``.
+
+        """
         try:
             self.__collection.update_one(
                 {"data_name": record["data_name"]}, {"$set": record}
@@ -655,15 +575,33 @@ class EEGDash:
         except:  # silent failure
             logger.error("Error updating record: %s", record["data_name"])

+    def exists(self, query: dict[str, Any]) -> bool:
+        """Alias for :meth:`exist` provided for API clarity."""
+        return self.exist(query)
+
     def remove_field(self, record, field):
-        """Remove a specific field from a record in the MongoDB collection.
+        """Remove a specific field from a record in the MongoDB collection.
+
+        Parameters
+        ----------
+        record : dict
+            Record identifying object with ``data_name``.
+        field : str
+            Field name to remove.
+
+        """
         self.__collection.update_one(
             {"data_name": record["data_name"]}, {"$unset": {field: 1}}
         )

     def remove_field_from_db(self, field):
-        """
-
+        """Remove a field from all records (destructive).
+
+        Parameters
+        ----------
+        field : str
+            Field name to remove from every document.
+
         """
         self.__collection.update_many({}, {"$unset": {field: 1}})

@@ -673,11 +611,13 @@ class EEGDash:
         return self.__collection

     def close(self):
-        """
+        """Backward-compatibility no-op; connections are managed globally.
+
+        Notes
+        -----
+        Connections are managed by :class:`MongoConnectionManager`. Use
+        :meth:`close_all_connections` to explicitly close all clients.

-        Note: Since MongoDB clients are now managed by a singleton,
-        this method no longer closes connections. Use close_all_connections()
-        class method to close all connections if needed.
         """
         # Individual instances no longer close the shared client
         pass
@@ -688,12 +628,78 @@ class EEGDash:
         MongoConnectionManager.close_all()

     def __del__(self):
-        """
+        """Destructor; no explicit action needed due to global connection manager."""
         # No longer needed since we're using singleton pattern
         pass


-class EEGDashDataset(BaseConcatDataset):
+class EEGDashDataset(BaseConcatDataset, metaclass=NumpyDocstringInheritanceInitMeta):
+    """Create a new EEGDashDataset from a given query or local BIDS dataset directory
+    and dataset name. An EEGDashDataset is a pooled collection of EEGDashBaseDataset
+    instances (individual recordings) and is a subclass of braindecode's BaseConcatDataset.
+
+    Examples
+    --------
+    # Find by single subject
+    >>> ds = EEGDashDataset(dataset="ds005505", subject="NDARCA153NKE")
+
+    # Find by a list of subjects and a specific task
+    >>> subjects = ["NDARCA153NKE", "NDARXT792GY8"]
+    >>> ds = EEGDashDataset(dataset="ds005505", subject=subjects, task="RestingState")
+
+    # Use a raw MongoDB query for advanced filtering
+    >>> raw_query = {"dataset": "ds005505", "subject": {"$in": subjects}}
+    >>> ds = EEGDashDataset(query=raw_query)
+
+    Parameters
+    ----------
+    cache_dir : str | Path
+        Directory where data are cached locally. If not specified, a default
+        cache directory under the user cache is used.
+    query : dict | None
+        Raw MongoDB query to filter records. If provided, it is merged with
+        keyword filtering arguments (see ``**kwargs``) using logical AND.
+        You must provide at least a ``dataset`` (either in ``query`` or
+        as a keyword argument). Only fields in ``ALLOWED_QUERY_FIELDS`` are
+        considered for filtering.
+    dataset : str
+        Dataset identifier (e.g., ``"ds002718"``). Required if ``query`` does
+        not already specify a dataset.
+    task : str | list[str]
+        Task name(s) to filter by (e.g., ``"RestingState"``).
+    subject : str | list[str]
+        Subject identifier(s) to filter by (e.g., ``"NDARCA153NKE"``).
+    session : str | list[str]
+        Session identifier(s) to filter by (e.g., ``"1"``).
+    run : str | list[str]
+        Run identifier(s) to filter by (e.g., ``"1"``).
+    description_fields : list[str]
+        Fields to extract from each record and include in dataset descriptions
+        (e.g., "subject", "session", "run", "task").
+    s3_bucket : str | None
+        Optional S3 bucket URI (e.g., "s3://mybucket") to use instead of the
+        default OpenNeuro bucket when downloading data files.
+    records : list[dict] | None
+        Pre-fetched metadata records. If provided, the dataset is constructed
+        directly from these records and no MongoDB query is performed.
+    download : bool, default True
+        If False, load from local BIDS files only. Local data are expected
+        under ``cache_dir / dataset``; no DB or S3 access is attempted.
+    n_jobs : int
+        Number of parallel jobs to use where applicable (-1 uses all cores).
+    eeg_dash_instance : EEGDash | None
+        Optional existing EEGDash client to reuse for DB queries. If None,
+        a new client is created on demand; not used when ``download=False``.
+    **kwargs : dict
+        Additional keyword arguments serving two purposes:
+
+        - Filtering: any keys present in ``ALLOWED_QUERY_FIELDS`` are treated as
+          query filters (e.g., ``dataset``, ``subject``, ``task``, ...).
+        - Dataset options: remaining keys are forwarded to
+          ``EEGDashBaseDataset``.
+
+    """
+
     def __init__(
         self,
         cache_dir: str | Path,
@@ -708,83 +714,64 @@ class EEGDashDataset(BaseConcatDataset):
             "sex",
         ],
         s3_bucket: str | None = None,
-        eeg_dash_instance=None,
         records: list[dict] | None = None,
-
+        download: bool = True,
         n_jobs: int = -1,
+        eeg_dash_instance: EEGDash | None = None,
         **kwargs,
     ):
-
-
-
-
-
-
-
-        # Find by single subject
-        >>> ds = EEGDashDataset(dataset="ds005505", subject="NDARCA153NKE")
-
-        # Find by a list of subjects and a specific task
-        >>> subjects = ["NDARCA153NKE", "NDARXT792GY8"]
-        >>> ds = EEGDashDataset(dataset="ds005505", subject=subjects, task="RestingState")
-
-        # Use a raw MongoDB query for advanced filtering
-        >>> raw_query = {"dataset": "ds005505", "subject": {"$in": subjects}}
-        >>> ds = EEGDashDataset(query=raw_query)
+        # Parameters that don't need validation
+        _suppress_comp_warning: bool = kwargs.pop("_suppress_comp_warning", False)
+        self.s3_bucket = s3_bucket
+        self.records = records
+        self.download = download
+        self.n_jobs = n_jobs
+        self.eeg_dash_instance = eeg_dash_instance or EEGDash()

-
-
-        query : dict | None
-            A raw MongoDB query dictionary. If provided, keyword arguments for filtering are ignored.
-        **kwargs : dict
-            Keyword arguments for filtering (e.g., `subject="X"`, `task=["T1", "T2"]`) and/or
-            arguments to be passed to the EEGDashBaseDataset constructor (e.g., `subject=...`).
-        cache_dir : str
-            A directory where the dataset will be cached locally.
-        data_dir : str | None
-            Optionally a string specifying a local BIDS dataset directory from which to load the EEG data files. Exactly one
-            of query or data_dir must be provided.
-        dataset : str | None
-            If data_dir is given, a name for the dataset to be loaded.
-        description_fields : list[str]
-            A list of fields to be extracted from the dataset records
-            and included in the returned data description(s). Examples are typical
-            subject metadata fields such as "subject", "session", "run", "task", etc.;
-            see also data_config.description_fields for the default set of fields.
-        s3_bucket : str | None
-            An optional S3 bucket URI (e.g., "s3://mybucket") to use instead of the
-            default OpenNeuro bucket for loading data files
-        records : list[dict] | None
-            Optional list of pre-fetched metadata records. If provided, the dataset is
-            constructed directly from these records without querying MongoDB.
-        offline_mode : bool
-            If True, do not attempt to query MongoDB at all. This is useful if you want to
-            work with a local cache only, or if you are offline.
-        n_jobs : int
-            The number of jobs to run in parallel (default is -1, meaning using all processors).
-        kwargs : dict
-            Additional keyword arguments to be passed to the EEGDashBaseDataset
-            constructor.
+        # Resolve a unified cache directory across code/tests/CI
+        self.cache_dir = Path(cache_dir or get_default_cache_dir())

-        """
-        self.cache_dir = Path(cache_dir or platformdirs.user_cache_dir("EEGDash"))
         if not self.cache_dir.exists():
             warn(f"Cache directory does not exist, creating it: {self.cache_dir}")
             self.cache_dir.mkdir(exist_ok=True, parents=True)
-        self.s3_bucket = s3_bucket
-        self.eeg_dash = eeg_dash_instance

         # Separate query kwargs from other kwargs passed to the BaseDataset constructor
         self.query = query or {}
         self.query.update(
-            {k: v for k, v in kwargs.items() if k in
+            {k: v for k, v in kwargs.items() if k in ALLOWED_QUERY_FIELDS}
         )
         base_dataset_kwargs = {k: v for k, v in kwargs.items() if k not in self.query}
         if "dataset" not in self.query:
-
-
-
-
+            # If explicit records are provided, infer dataset from records
+            if isinstance(records, list) and records and isinstance(records[0], dict):
+                inferred = records[0].get("dataset")
+                if inferred:
+                    self.query["dataset"] = inferred
+                else:
+                    raise ValueError("You must provide a 'dataset' argument")
+            else:
+                raise ValueError("You must provide a 'dataset' argument")
+
+        # Decide on a dataset subfolder name for cache isolation. If using
+        # challenge/preprocessed buckets (e.g., BDF, mini subsets), append
+        # informative suffixes to avoid overlapping with the original dataset.
+        dataset_folder = self.query["dataset"]
+        if self.s3_bucket:
+            suffixes: list[str] = []
+            bucket_lower = str(self.s3_bucket).lower()
+            if "bdf" in bucket_lower:
+                suffixes.append("bdf")
+            if "mini" in bucket_lower:
+                suffixes.append("mini")
+            if suffixes:
+                dataset_folder = f"{dataset_folder}-{'-'.join(suffixes)}"
+
+        self.data_dir = self.cache_dir / dataset_folder
+
+        if (
+            not _suppress_comp_warning
+            and self.query["dataset"] in RELEASE_TO_OPENNEURO_DATASET_MAP.values()
+        ):
             warn(
                 "If you are not participating in the competition, you can ignore this warning!"
                 "\n\n"
@@ -794,76 +781,213 @@ class EEGDashDataset(BaseConcatDataset):
                 "IMPORTANT: The data accessed via `EEGDashDataset` is NOT identical to what you get from `EEGChallengeDataset` object directly.\n"
                 "and it is not what you will use for the competition. Downsampling and filtering were applied to the data"
                 "to allow more people to participate.\n"
-                "\n"
+                "\n"
                 "If you are participating in the competition, always use `EEGChallengeDataset` to ensure consistency with the challenge data.\n"
                 "\n",
                 UserWarning,
                 module="eegdash",
             )
-
-
-
-
+        if records is not None:
+            self.records = records
+            datasets = [
+                EEGDashBaseDataset(
+                    record,
+                    self.cache_dir,
+                    self.s3_bucket,
+                    **base_dataset_kwargs,
+                )
+                for record in self.records
+            ]
+        elif not download:  # only assume local data is complete if not downloading
+            if not self.data_dir.exists():
+                raise ValueError(
+                    f"Offline mode is enabled, but local data_dir {self.data_dir} does not exist."
+                )
+            records = self._find_local_bids_records(self.data_dir, self.query)
+            # Try to enrich from local participants.tsv to restore requested fields
+            try:
+                bids_ds = EEGBIDSDataset(
+                    data_dir=str(self.data_dir), dataset=self.query["dataset"]
+                )  # type: ignore[index]
+            except Exception:
+                bids_ds = None
+
+            datasets = []
+            for record in records:
+                # Start with entity values from filename
+                desc: dict[str, Any] = {
+                    k: record.get(k)
+                    for k in ("subject", "session", "run", "task")
+                    if record.get(k) is not None
+                }
+
+                if bids_ds is not None:
+                    try:
+                        rel_from_dataset = Path(record["bidspath"]).relative_to(
+                            record["dataset"]
+                        )  # type: ignore[index]
+                        local_file = (self.data_dir / rel_from_dataset).as_posix()
+                        part_row = bids_ds.subject_participant_tsv(local_file)
+                        desc = merge_participants_fields(
+                            description=desc,
+                            participants_row=part_row
+                            if isinstance(part_row, dict)
+                            else None,
+                            description_fields=description_fields,
+                        )
+                    except Exception:
+                        pass

-
-        if records is not None:
-            self.records = records
-            datasets = [
+                datasets.append(
                     EEGDashBaseDataset(
-                    record,
-                    self.cache_dir,
-                    self.s3_bucket,
-
-                )
-                for record in self.records
-            ]
-        elif offline_mode:  # only assume local data is complete if in offline mode
-            if self.data_dir.exists():
-                # This path loads from a local directory and is not affected by DB query logic
-                datasets = self.load_bids_daxtaset(
-                    dataset=self.query["dataset"],
-                    data_dir=self.data_dir,
-                    description_fields=description_fields,
-                    s3_bucket=s3_bucket,
-                    n_jobs=n_jobs,
+                        record=record,
+                        cache_dir=self.cache_dir,
+                        s3_bucket=self.s3_bucket,
+                        description=desc,
                         **base_dataset_kwargs,
                     )
-            else:
-                raise ValueError(
-                    f"Offline mode is enabled, but local data_dir {self.data_dir} does not exist."
-                )
-        elif self.query:
-            # This is the DB query path that we are improving
-            datasets = self._find_datasets(
-                query=self.eeg_dash._build_query_from_kwargs(**self.query),
-                description_fields=description_fields,
-                base_dataset_kwargs=base_dataset_kwargs,
-            )
-            # We only need filesystem if we need to access S3
-            self.filesystem = S3FileSystem(
-                anon=True, client_kwargs={"region_name": "us-east-2"}
-            )
-        else:
-            raise ValueError(
-                "You must provide either 'records', a 'data_dir', or a query/keyword arguments for filtering."
             )
-
-
-
+        elif self.query:
+            # This is the DB query path that we are improving
+            datasets = self._find_datasets(
+                query=build_query_from_kwargs(**self.query),
+                description_fields=description_fields,
+                base_dataset_kwargs=base_dataset_kwargs,
+            )
+            # We only need filesystem if we need to access S3
+            self.filesystem = S3FileSystem(
+                anon=True, client_kwargs={"region_name": "us-east-2"}
+            )
+        else:
+            raise ValueError(
+                "You must provide either 'records', a 'data_dir', or a query/keyword arguments for filtering."
+            )

         super().__init__(datasets)

-    def
-
-
+    def _find_local_bids_records(
+        self, dataset_root: Path, filters: dict[str, Any]
+    ) -> list[dict]:
+        """Discover local BIDS EEG files and build minimal records.
+
+        This helper enumerates EEG recordings under ``dataset_root`` via
+        ``mne_bids.find_matching_paths`` and applies entity filters to produce a
+        list of records suitable for ``EEGDashBaseDataset``. No network access
+        is performed and files are not read.
+
+        Parameters
+        ----------
+        dataset_root : Path
+            Local dataset directory. May be the plain dataset folder (e.g.,
+            ``ds005509``) or a suffixed cache variant (e.g.,
+            ``ds005509-bdf-mini``).
+        filters : dict of {str, Any}
+            Query filters. Must include ``'dataset'`` with the dataset id (without
+            local suffixes). May include BIDS entities ``'subject'``,
+            ``'session'``, ``'task'``, and ``'run'``. Each value can be a scalar
+            or a sequence of scalars.
+
+        Returns
+        -------
+        records : list of dict
+            One record per matched EEG file with at least:
+
+            - ``'data_name'``
+            - ``'dataset'`` (dataset id, without suffixes)
+            - ``'bidspath'`` (normalized to start with the dataset id)
+            - ``'subject'``, ``'session'``, ``'task'``, ``'run'`` (may be None)
+            - ``'bidsdependencies'`` (empty list)
+            - ``'modality'`` (``"eeg"``)
+            - ``'sampling_frequency'``, ``'nchans'``, ``'ntimes'`` (minimal
+              defaults for offline usage)
+
+        Notes
+        -----
+        - Matching uses ``datatypes=['eeg']`` and ``suffixes=['eeg']``.
+        - ``bidspath`` is constructed as
+          ``<dataset_id> / <relative_path_from_dataset_root>`` to ensure the
+          first path component is the dataset id (without local cache suffixes).
+        - Minimal defaults are set for ``sampling_frequency``, ``nchans``, and
+          ``ntimes`` to satisfy dataset length requirements offline.
+
         """
+        dataset_id = filters["dataset"]
+        arg_map = {
+            "subjects": "subject",
+            "sessions": "session",
+            "tasks": "task",
+            "runs": "run",
+        }
+        matching_args: dict[str, list[str]] = {}
+        for finder_key, entity_key in arg_map.items():
+            entity_val = filters.get(entity_key)
+            if entity_val is None:
+                continue
+            if isinstance(entity_val, (list, tuple, set)):
+                entity_vals = list(entity_val)
+                if not entity_vals:
+                    continue
+                matching_args[finder_key] = entity_vals
+            else:
+                matching_args[finder_key] = [entity_val]
+
+        matched_paths = find_matching_paths(
+            root=str(dataset_root),
+            datatypes=["eeg"],
+            suffixes=["eeg"],
+            ignore_json=True,
+            **matching_args,
+        )
+        records_out: list[dict] = []
+
+        for bids_path in matched_paths:
+            # Build bidspath as dataset_id / relative_path_from_dataset_root (POSIX)
+            rel_from_root = (
+                Path(bids_path.fpath)
+                .resolve()
+                .relative_to(Path(bids_path.root).resolve())
+            )
+            bidspath = f"{dataset_id}/{rel_from_root.as_posix()}"
+
+            rec = {
+                "data_name": f"{dataset_id}_{Path(bids_path.fpath).name}",
+                "dataset": dataset_id,
+                "bidspath": bidspath,
+                "subject": (bids_path.subject or None),
+                "session": (bids_path.session or None),
+                "task": (bids_path.task or None),
+                "run": (bids_path.run or None),
+                # minimal fields to satisfy BaseDataset from eegdash
+                "bidsdependencies": [],  # not needed to just run.
+                "modality": "eeg",
+                # minimal numeric defaults for offline length calculation
+                "sampling_frequency": None,
+                "nchans": None,
+                "ntimes": None,
+            }
+            records_out.append(rec)
+
+        return records_out
+
+    def _find_key_in_nested_dict(self, data: Any, target_key: str) -> Any:
+        """Recursively search for target_key in nested dicts/lists with normalized matching.
+
+        This makes lookups tolerant to naming differences like "p-factor" vs "p_factor".
+        Returns the first match or None.
+        """
+        norm_target = normalize_key(target_key)
         if isinstance(data, dict):
-
-
-
-
-            if
-                return
+            for k, v in data.items():
+                if normalize_key(k) == norm_target:
+                    return v
+                res = self._find_key_in_nested_dict(v, target_key)
+                if res is not None:
+                    return res
+        elif isinstance(data, list):
+            for item in data:
+                res = self._find_key_in_nested_dict(item, target_key)
+                if res is not None:
+                    return res
         return None

     def _find_datasets(
@@ -892,15 +1016,23 @@ class EEGDashDataset(BaseConcatDataset):

         """
         datasets: list[EEGDashBaseDataset] = []
-
-        self.records = self.eeg_dash.find(query)
+        self.records = self.eeg_dash_instance.find(query)

         for record in self.records:
-            description = {}
+            description: dict[str, Any] = {}
+            # Requested fields first (normalized matching)
             for field in description_fields:
-                value = self.
+                value = self._find_key_in_nested_dict(record, field)
                 if value is not None:
                     description[field] = value
+            # Merge all participants.tsv columns generically
+            part = self._find_key_in_nested_dict(record, "participant_tsv")
+            if isinstance(part, dict):
+                description = merge_participants_fields(
+                    description=description,
+                    participants_row=part,
+                    description_fields=description_fields,
+                )
             datasets.append(
                 EEGDashBaseDataset(
                     record,
@@ -911,69 +1043,3 @@ class EEGDashDataset(BaseConcatDataset):
                 )
             )
         return datasets
-
-    def load_bids_dataset(
-        self,
-        dataset: str,
-        data_dir: str | Path,
-        description_fields: list[str],
-        s3_bucket: str | None = None,
-        n_jobs: int = -1,
-        **kwargs,
-    ):
-        """Helper method to load a single local BIDS dataset and return it as a list of
-        EEGDashBaseDatasets (one for each recording in the dataset).
-
-        Parameters
-        ----------
-        dataset : str
-            A name for the dataset to be loaded (e.g., "ds002718").
-        data_dir : str
-            The path to the local BIDS dataset directory.
-        description_fields : list[str]
-            A list of fields to be extracted from the dataset records
-            and included in the returned dataset description(s).
-        s3_bucket : str | None
-            The S3 bucket to upload the dataset files to (if any).
-        n_jobs : int
-            The number of jobs to run in parallel (default is -1, meaning using all processors).
-
-        """
-        bids_dataset = EEGBIDSDataset(
-            data_dir=data_dir,
-            dataset=dataset,
-        )
-        datasets = Parallel(n_jobs=n_jobs, prefer="threads", verbose=1)(
-            delayed(self.get_base_dataset_from_bids_file)(
-                bids_dataset=bids_dataset,
-                bids_file=bids_file,
-                s3_bucket=s3_bucket,
-                description_fields=description_fields,
-                **kwargs,
-            )
-            for bids_file in bids_dataset.get_files()
-        )
-        return datasets
-
-    def get_base_dataset_from_bids_file(
-        self,
-        bids_dataset: "EEGBIDSDataset",
-        bids_file: str,
-        s3_bucket: str | None,
-        description_fields: list[str],
-        **kwargs,
-    ) -> "EEGDashBaseDataset":
-        """Instantiate a single EEGDashBaseDataset given a local BIDS file (metadata only)."""
-        record = self.eeg_dash.load_eeg_attrs_from_bids_file(bids_dataset, bids_file)
-        description = {}
-        for field in description_fields:
-            value = self.find_key_in_nested_dict(record, field)
-            if value is not None:
-                description[field] = value
-        return EEGDashBaseDataset(
-            record,
-            self.cache_dir,
-            s3_bucket,
-            description=description,
-            **kwargs,
-        )