eegdash 0.1.0__py3-none-any.whl → 0.2.0__py3-none-any.whl
- eegdash/__init__.py +7 -3
- eegdash/api.py +690 -0
- eegdash/data_config.py +7 -1
- eegdash/data_utils.py +215 -118
- eegdash/dataset.py +60 -0
- eegdash/features/__init__.py +37 -9
- eegdash/features/datasets.py +57 -21
- eegdash/features/decorators.py +10 -2
- eegdash/features/extractors.py +20 -21
- eegdash/features/feature_bank/complexity.py +4 -0
- eegdash/features/feature_bank/csp.py +2 -2
- eegdash/features/feature_bank/dimensionality.py +7 -3
- eegdash/features/feature_bank/signal.py +29 -3
- eegdash/features/inspect.py +48 -0
- eegdash/features/serialization.py +2 -3
- eegdash/features/utils.py +1 -1
- eegdash/preprocessing.py +65 -0
- eegdash/utils.py +11 -0
- {eegdash-0.1.0.dist-info → eegdash-0.2.0.dist-info}/METADATA +49 -6
- eegdash-0.2.0.dist-info/RECORD +27 -0
- {eegdash-0.1.0.dist-info → eegdash-0.2.0.dist-info}/WHEEL +1 -1
- {eegdash-0.1.0.dist-info → eegdash-0.2.0.dist-info}/licenses/LICENSE +1 -0
- eegdash/main.py +0 -416
- eegdash-0.1.0.dist-info/RECORD +0 -23
- {eegdash-0.1.0.dist-info → eegdash-0.2.0.dist-info}/top_level.txt +0 -0
eegdash/data_config.py
CHANGED
@@ -1,5 +1,8 @@
 config = {
     "required_fields": ["data_name"],
+    # Default set of user-facing primary record attributes expected in the database. Records
+    # where any of these are missing will be loaded with the respective attribute set to None.
+    # Additional fields may be returned if they are present in the database, notably bidsdependencies.
     "attributes": {
         "data_name": "str",
         "dataset": "str",
@@ -11,9 +14,12 @@ config = {
         "sampling_frequency": "float",
         "modality": "str",
         "nchans": "int",
-        "ntimes": "int",
+        "ntimes": "int",  # note: this is really the number of seconds in the data, rounded down
     },
+    # queryable descriptive fields for a given recording
     "description_fields": ["subject", "session", "run", "task", "age", "gender", "sex"],
+    # list of filenames that may be present in the BIDS dataset directory that are used
+    # to load and interpret a given BIDS recording.
     "bids_dependencies_files": [
         "dataset_description.json",
         "participants.tsv",
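Note: the added comments spell out the loading contract: any attribute declared in "attributes" that is absent from a database record is loaded as None, while extra fields (notably bidsdependencies) pass through untouched. A minimal sketch of that normalization; the helper name normalize_record is hypothetical, not part of eegdash:

from eegdash.data_config import config

def normalize_record(record: dict) -> dict:
    # Hypothetical helper mirroring the contract described in the comments:
    # declared attributes default to None, extra database fields pass through.
    normalized = dict(record)
    for attr in config["attributes"]:
        normalized.setdefault(attr, None)
    return normalized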
eegdash/data_utils.py
CHANGED
@@ -1,9 +1,9 @@
 import json
+import logging
 import os
 import re
-import sys
-import tempfile
 from pathlib import Path
+from typing import Any
 
 import mne
 import mne_bids
@@ -12,7 +12,7 @@ import pandas as pd
 import s3fs
 from bids import BIDSLayout
 from joblib import Parallel, delayed
-from mne._fiff.utils import
+from mne._fiff.utils import _read_segments_file
 from mne.io import BaseRaw
 from mne_bids import (
     BIDSPath,
@@ -20,51 +20,62 @@ from mne_bids import (
 
 from braindecode.datasets import BaseDataset
 
+logger = logging.getLogger("eegdash")
 
-class EEGDashBaseDataset(BaseDataset):
-    """Returns samples from an mne.io.Raw object along with a target.
-
-    `description` attribute.
-
-    raw : mne.io.Raw
-        Continuous data.
-    description : dict | pandas.Series | None
-        Holds additional description about the continuous signal / subject.
-    target_name : str | tuple | None
-        Name(s) of the index in `description` that should be used to provide the
-        target (e.g., to be used in a prediction task later on).
-    transform : callable | None
-        On-the-fly transform applied to the example before it is returned.
+class EEGDashBaseDataset(BaseDataset):
+    """A single EEG recording hosted on AWS S3 and cached locally upon first access.
+
+    This is a subclass of braindecode's BaseDataset, which can consequently be used in
+    conjunction with the preprocessing and training pipelines of braindecode.
     """
 
     AWS_BUCKET = "s3://openneuro.org"
 
-    def __init__(
+    def __init__(
+        self,
+        record: dict[str, Any],
+        cache_dir: str,
+        s3_bucket: str | None = None,
+        **kwargs,
+    ):
+        """Create a new EEGDashBaseDataset instance. Users do not usually need to call this
+        directly -- instead use the EEGDashDataset class to load a collection of these
+        recordings from a local BIDS folder or using a database query.
+
+        Parameters
+        ----------
+        record : dict
+            A fully resolved metadata record for the data to load.
+        cache_dir : str
+            A local directory where the data will be cached.
+        kwargs : dict
+            Additional keyword arguments to pass to the BaseDataset constructor.
+
+        """
         super().__init__(None, **kwargs)
         self.record = record
         self.cache_dir = Path(cache_dir)
         bids_kwargs = self.get_raw_bids_args()
+
         self.bidspath = BIDSPath(
             root=self.cache_dir / record["dataset"],
             datatype="eeg",
             suffix="eeg",
             **bids_kwargs,
         )
+        self.s3_bucket = s3_bucket if s3_bucket else self.AWS_BUCKET
         self.s3file = self.get_s3path(record["bidspath"])
         self.filecache = self.cache_dir / record["bidspath"]
         self.bids_dependencies = record["bidsdependencies"]
         self._raw = None
-        # if os.path.exists(self.filecache):
-        #     self.raw = mne_bids.read_raw_bids(self.bidspath, verbose=False)
 
-    def get_s3path(self, filepath):
+    def get_s3path(self, filepath: str) -> str:
+        """Helper to form an AWS S3 URI for the given relative filepath."""
+        return f"{self.s3_bucket}/{filepath}"
 
-    def _download_s3(self):
+    def _download_s3(self) -> None:
+        """Fetch the given data from its S3 location and cache it locally."""
         self.filecache.parent.mkdir(parents=True, exist_ok=True)
         filesystem = s3fs.S3FileSystem(
             anon=True, client_kwargs={"region_name": "us-east-2"}
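As the new docstring says, EEGDashBaseDataset instances are not normally constructed by hand. A hedged usage sketch via the collection class from api.py (the query value is illustrative; see EEGDash.find() for the query format):

from eegdash.api import EEGDashDataset

# Each entry in ds.datasets is an EEGDashBaseDataset that is downloaded
# from S3 and cached locally on first access.
ds = EEGDashDataset(
    query={"dataset": "ds005505"},
    cache_dir=".eegdash_cache",
)
recording = ds.datasets[0]
print(recording.record["dataset"], recording.bidspath)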
@@ -72,7 +83,10 @@ class EEGDashBaseDataset(BaseDataset):
         filesystem.download(self.s3file, self.filecache)
         self.filenames = [self.filecache]
 
-    def _download_dependencies(self):
+    def _download_dependencies(self) -> None:
+        """Download all BIDS dependency files (metadata files, recording sidecar files)
+        from S3 and cache them locally.
+        """
         filesystem = s3fs.S3FileSystem(
             anon=True, client_kwargs={"region_name": "us-east-2"}
         )
@@ -83,11 +97,15 @@ class EEGDashBaseDataset(BaseDataset):
             filepath.parent.mkdir(parents=True, exist_ok=True)
             filesystem.download(s3path, filepath)
 
-    def get_raw_bids_args(self):
+    def get_raw_bids_args(self) -> dict[str, Any]:
+        """Helper to restrict the metadata record to the fields needed to locate a BIDS
+        recording.
+        """
         desired_fields = ["subject", "session", "task", "run"]
         return {k: self.record[k] for k in desired_fields if self.record[k]}
 
-    def check_and_get_raw(self):
+    def check_and_get_raw(self) -> None:
+        """Download the S3 file and BIDS dependencies if not already cached."""
         if not os.path.exists(self.filecache):  # not preload
             if self.bids_dependencies:
                 self._download_dependencies()
@@ -95,9 +113,10 @@ class EEGDashBaseDataset(BaseDataset):
         if self._raw is None:
             self._raw = mne_bids.read_raw_bids(self.bidspath, verbose=False)
 
-
-        # self.check_and_get_raw()
+    # === BaseDataset and PyTorch Dataset interface ===
 
+    def __getitem__(self, index):
+        """Main function to access a sample from the dataset."""
         X = self.raw[:, index][0]
         y = None
         if self.target_name is not None:
@@ -108,14 +127,20 @@ class EEGDashBaseDataset(BaseDataset):
             X = self.transform(X)
         return X, y
 
-    def __len__(self):
+    def __len__(self) -> int:
+        """Return the number of samples in the dataset."""
         if self._raw is None:
+            # FIXME: this is a bit strange and should definitely not change as a side effect
+            # of accessing the data (which it will, since ntimes is the actual length but rounded down)
             return int(self.record["ntimes"] * self.record["sampling_frequency"])
         else:
             return len(self._raw)
 
     @property
     def raw(self):
+        """Return the MNE Raw object for this recording. This will perform the actual
+        retrieval if not yet done so.
+        """
         if self._raw is None:
             self.check_and_get_raw()
         return self._raw
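The FIXME above is worth illustrating: before any data access, __len__ is estimated from the rounded-down ntimes field, and once the raw property triggers the download it reports the true length, so the two can disagree. A small sketch, reusing the hypothetical ds from the earlier example:

recording = ds.datasets[0]
estimated = len(recording)  # ntimes * sampling_frequency (ntimes rounded down)
X, y = recording[0]         # indexing touches .raw, triggering the S3 download
actual = len(recording)     # now len(self._raw), which may exceed the estimate
print(estimated, actual)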
@@ -126,50 +151,44 @@ class EEGDashBaseDataset(BaseDataset):
 
 
 class EEGDashBaseRaw(BaseRaw):
-
+    """Wrapper around the MNE BaseRaw class that automatically fetches the data from S3
+    (when _read_segment is called) and caches it locally. Currently for internal use.
 
     Parameters
     ----------
     input_fname : path-like
         Path to the S3 file
-
-
-
-
-
-
-
-
-
-
+    metadata : dict
+        The metadata record for the recording (e.g., from the database).
+    preload : bool
+        Whether to preload the data before the first access.
+    cache_dir : str
+        Local path under which the data will be cached.
+    bids_dependencies : list
+        List of additional BIDS metadata files that should be downloaded and cached
+        alongside the main recording file.
+    verbose : str | int | None
+        Optionally the verbosity level for MNE logging (see MNE documentation for possible values).
 
     See Also
     --------
     mne.io.Raw : Documentation of attributes and methods.
 
-    Notes
-    -----
-    .. versionadded:: 0.11.0
     """
 
     AWS_BUCKET = "s3://openneuro.org"
 
     def __init__(
         self,
-        input_fname,
-        metadata,
-
-        preload=False,
+        input_fname: str,
+        metadata: dict[str, Any],
+        preload: bool = False,
         *,
-        cache_dir="./.eegdash_cache",
-        bids_dependencies: list = [],
-
-        montage_units="auto",
-        verbose=None,
+        cache_dir: str = "./.eegdash_cache",
+        bids_dependencies: list[str] = [],
+        verbose: Any = None,
     ):
-        """
-        Get to work with S3 endpoint first, no caching
-        """
+        """Get to work with S3 endpoint first, no caching"""
         # Create a simple RawArray
         sfreq = metadata["sfreq"]  # Sampling frequency
         n_times = metadata["n_times"]
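EEGDashBaseRaw follows MNE's documented pattern for custom readers: a BaseRaw subclass overrides _read_segment_file and delegates the binary reads to mne._fiff.utils._read_segments_file (the import pinned down above). A generic sketch of that pattern, not eegdash's exact implementation (whose method bodies are elided from this diff):

from mne.io import BaseRaw
from mne._fiff.utils import _read_segments_file

class LazyFetchRaw(BaseRaw):
    # Sketch only: a real subclass must also build an mne.Info and call
    # super().__init__ with filenames, as EEGDashBaseRaw.__init__ does.

    def _read_segment_file(self, data, idx, fi, start, stop, cals, mult):
        # This is the hook where EEGDashBaseRaw can check its local cache
        # and download from S3 before any samples are actually read.
        _read_segments_file(
            self, data, idx, fi, start, stop, cals, mult, dtype="<f4"
        )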
@@ -237,6 +256,20 @@ class EEGDashBaseRaw(BaseRaw):
 
 
 class EEGBIDSDataset:
+    """A one-stop shop interface to a local BIDS dataset containing EEG recordings.
+
+    This is mainly tailored to the needs of the EEGDash application and is used to centralize
+    interactions with the BIDS dataset, such as parsing the metadata.
+
+    Parameters
+    ----------
+    data_dir : str | Path
+        The path to the local BIDS dataset directory.
+    dataset : str
+        A name for the dataset.
+
+    """
+
     ALLOWED_FILE_FORMAT = ["eeglab", "brainvision", "biosemi", "european"]
     RAW_EXTENSIONS = {
         ".set": [".set", ".fdt"],  # eeglab
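A short usage sketch for this class, assuming a BIDS dataset has already been downloaded to a local folder (the path is illustrative):

from eegdash.data_utils import EEGBIDSDataset

bids = EEGBIDSDataset(data_dir="/data/ds005505", dataset="ds005505")
first = bids.get_files()[0]
print(bids.get_bids_file_attribute("modality", first))  # e.g. "eeg"
print(bids.channel_labels(first)[:5])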
@@ -270,19 +303,13 @@ class EEGBIDSDataset:
                 "Unable to construct EEG dataset. No EEG recordings found."
             )
         assert self.check_eeg_dataset(), ValueError("Dataset is not an EEG dataset.")
-
-
-
-        # if not os.path.exists(temp_dir / f'{dataset}_files.npy'):
-        #     self.files = self.get_files_with_extension_parallel(self.bidsdir, extension=self.RAW_EXTENSION[self.raw_format])
-        #     np.save(temp_dir / f'{dataset}_files.npy', self.files)
-        # else:
-        #     self.files = np.load(temp_dir / f'{dataset}_files.npy', allow_pickle=True)
-
-    def check_eeg_dataset(self):
+
+    def check_eeg_dataset(self) -> bool:
+        """Check if the dataset is EEG."""
         return self.get_bids_file_attribute("modality", self.files[0]).lower() == "eeg"
 
-    def get_recordings(self, layout: BIDSLayout):
+    def get_recordings(self, layout: BIDSLayout) -> list[str]:
+        """Get a list of all EEG recording files in the BIDS layout."""
         files = []
         for ext, exts in self.RAW_EXTENSIONS.items():
             files = layout.get(extension=ext, return_type="filename")
@@ -290,11 +317,15 @@ class EEGBIDSDataset:
                 break
         return files
 
-    def get_relative_bidspath(self, filename):
+    def get_relative_bidspath(self, filename: str) -> str:
+        """Make the given file path relative to the BIDS directory."""
         bids_parent_dir = self.bidsdir.parent.absolute()
         return str(Path(filename).relative_to(bids_parent_dir))
 
-    def get_property_from_filename(self, property, filename):
+    def get_property_from_filename(self, property: str, filename: str) -> str:
+        """Parse a property out of a BIDS-compliant filename. Returns an empty string
+        if not found.
+        """
         import platform
 
         if platform.system() == "Windows":
@@ -303,25 +334,38 @@ class EEGBIDSDataset:
             lookup = re.search(rf"{property}-(.*?)[_\/]", filename)
         return lookup.group(1) if lookup else ""
 
-    def merge_json_inheritance(self, json_files):
-        """
-
-
-        and return a merged dictionary
+    def merge_json_inheritance(self, json_files: list[str | Path]) -> dict:
+        """Internal helper to merge a list of json files found by get_bids_file_inheritance,
+        expecting that the order (from left to right) is from lowest
+        level to highest level, and return a merged dictionary
         """
         json_files.reverse()
         json_dict = {}
         for f in json_files:
-            json_dict.update(json.load(open(f)))
+            json_dict.update(json.load(open(f)))  # FIXME: should close file
         return json_dict
 
-    def get_bids_file_inheritance(
-
-
-
-
-
-
+    def get_bids_file_inheritance(
+        self, path: str | Path, basename: str, extension: str
+    ) -> list[Path]:
+        """Get all file paths that apply to the basename file in the specified directory
+        and that end with the specified suffix, recursively searching parent directories
+        (following the BIDS inheritance principle in the order of lowest level first).
+
+        Parameters
+        ----------
+        path : str | Path
+            The directory path to search for files.
+        basename : str
+            The BIDS file basename, e.g., without the _eeg.set extension
+        extension : str
+            Only consider files that end with the specified suffix; e.g. channels.tsv
+
+        Returns
+        -------
+        list[Path]
+            A list of file paths that match the given basename and extension.
+
         """
         top_level_files = ["README", "dataset_description.json", "participants.tsv"]
         bids_files = []
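The new inline FIXME is easy to act on: json.load(open(f)) leaks a file handle. A hedged standalone variant that closes files with a context manager while preserving the merge order (lowest, most specific level wins):

import json
from pathlib import Path

def merge_json_inheritance(json_files: list[str | Path]) -> dict:
    # Input is ordered lowest level first; iterate highest level first so
    # that later updates from more specific levels overwrite the defaults.
    merged: dict = {}
    for f in reversed(json_files):
        with open(f) as fp:  # unlike the version above, the handle is closed
            merged.update(json.load(fp))
    return merged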
@@ -352,17 +396,25 @@ class EEGBIDSDataset:
         )
         return bids_files
 
-    def get_bids_metadata_files(
-
-
-
-
-
-
+    def get_bids_metadata_files(
+        self, filepath: str | Path, metadata_file_extension: list[str]
+    ) -> list[Path]:
+        """Retrieve all metadata file paths that apply to a given data file path and that
+        end with a specific suffix (following the BIDS inheritance principle).
+
+        Parameters
+        ----------
+        filepath : str | Path
+            The filepath to get the associated metadata files for.
+        metadata_file_extension : str
+            Consider only metadata files that end with the specified suffix,
+            e.g., channels.tsv or eeg.json
+
+        Returns
+        -------
+        list[Path]:
+            A list of filepaths for all matching metadata files
 
-        Returns:
-            list: A list of filepaths for all the associated metadata files
         """
         if isinstance(filepath, str):
             filepath = Path(filepath)
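In practice this helper pairs with merge_json_inheritance: collect every applicable sidecar along the inheritance chain, then fold them into one effective dictionary. Continuing the earlier EEGBIDSDataset sketch:

# Effective eeg.json for one recording; leaf-level values win the merge.
eeg_jsons = bids.get_bids_metadata_files(first, "eeg.json")
effective = bids.merge_json_inheritance(eeg_jsons)
print(effective["SamplingFrequency"], effective.get("RecordingDuration"))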
@@ -376,7 +428,11 @@ class EEGBIDSDataset:
         )
         return meta_files
 
-    def scan_directory(self, directory, extension):
+    def scan_directory(self, directory: str, extension: str) -> list[Path]:
+        """Return a list of file paths that end with the given extension in the specified
+        directory. Ignores certain special directories like .git, .datalad, derivatives,
+        and code.
+        """
         result_files = []
         directory_to_ignore = [".git", ".datalad", "derivatives", "code"]
         with os.scandir(directory) as entries:
@@ -391,14 +447,35 @@ class EEGBIDSDataset:
         return result_files
 
     def get_files_with_extension_parallel(
-        self, directory, extension=".set", max_workers
-    ):
+        self, directory: str, extension: str = ".set", max_workers: int = -1
+    ) -> list[Path]:
+        """Efficiently scan a directory and its subdirectories for files that end with
+        the given extension.
+
+        Parameters
+        ----------
+        directory : str
+            The root directory to scan for files.
+        extension : str
+            Only consider files that end with this suffix, e.g. '.set'.
+        max_workers : int
+            Optionally specify the maximum number of worker threads to use for parallel scanning.
+            Defaults to all available CPU cores if set to -1.
+
+        Returns
+        -------
+        list[Path]:
+            A list of filepaths for all matching files
+
+        """
         result_files = []
         dirs_to_scan = [directory]
 
         # Use joblib.Parallel and delayed to parallelize directory scanning
         while dirs_to_scan:
-
+            logger.info(
+                f"Directories to scan: {len(dirs_to_scan)}, files: {dirs_to_scan}"
+            )
             # Run the scan_directory function in parallel across directories
             results = Parallel(n_jobs=max_workers, prefer="threads", verbose=1)(
                 delayed(self.scan_directory)(d, extension) for d in dirs_to_scan
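The method is a breadth-first walk in which each level's directories are fanned out to a joblib thread pool, with scan_directory returning files and subdirectories mixed together. A self-contained sketch of the same pattern, simplified from the class methods above:

import os
from pathlib import Path
from joblib import Parallel, delayed

def scan_directory(directory: str, extension: str) -> list[Path]:
    ignore = {".git", ".datalad", "derivatives", "code"}
    found = []
    with os.scandir(directory) as entries:
        for e in entries:
            if e.is_dir() and e.name not in ignore:
                found.append(Path(e.path))  # directory: scan on the next level
            elif e.is_file() and e.name.endswith(extension):
                found.append(Path(e.path))  # matching file: keep
    return found

def find_files(root: str, extension: str = ".set", n_jobs: int = -1) -> list[Path]:
    files: list[Path] = []
    to_scan: list = [root]
    while to_scan:  # one Parallel call per directory level
        results = Parallel(n_jobs=n_jobs, prefer="threads")(
            delayed(scan_directory)(d, extension) for d in to_scan
        )
        to_scan = []
        for batch in results:
            for path in batch:
                (to_scan if path.is_dir() else files).append(path)
    return files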
@@ -412,12 +489,18 @@ class EEGBIDSDataset:
                     dirs_to_scan.append(path)  # Queue up subdirectories to scan
                 else:
                     result_files.append(path)  # Add files to the final result
-
+        logger.info(f"Found {len(result_files)} files.")
 
         return result_files
 
-    def load_and_preprocess_raw(
-
+    def load_and_preprocess_raw(
+        self, raw_file: str, preprocess: bool = False
+    ) -> np.ndarray:
+        """Utility function to load a raw data file with MNE and apply some simple
+        (hardcoded) preprocessing and return as a numpy array. Not meant for purposes
+        other than testing or debugging.
+        """
+        logger.info(f"Loading raw data from {raw_file}")
         EEG = mne.io.read_raw_eeglab(raw_file, preload=True, verbose="error")
 
         if preprocess:
@@ -429,9 +512,6 @@ class EEGBIDSDataset:
             sfreq = 128
             if EEG.info["sfreq"] != sfreq:
                 EEG = EEG.resample(sfreq)
-            # # normalize data to zero mean and unit variance
-            # scalar = preprocessing.StandardScaler()
-            # mat_data = scalar.fit_transform(mat_data.T).T  # scalar normalize for each feature and expects shape data x features
 
         mat_data = EEG.get_data()
 
@@ -439,17 +519,22 @@ class EEGBIDSDataset:
             raise ValueError("Expect raw data to be CxT dimension")
         return mat_data
 
-    def get_files(self):
+    def get_files(self) -> list[Path]:
+        """Get all EEG recording file paths (with valid extensions) in the BIDS folder."""
         return self.files
 
-    def resolve_bids_json(self, json_files: list):
-        """
-        Resolve the BIDS JSON files and return a dictionary of the resolved values.
-        Args:
-            json_files (list): A list of JSON files to resolve in order of leaf level first
+    def resolve_bids_json(self, json_files: list[str]) -> dict:
+        """Resolve the BIDS JSON files and return a dictionary of the resolved values.
 
-
+        Parameters
+        ----------
+        json_files : list
+            A list of JSON file paths to resolve in order of leaf level first.
+
+        Returns
+        -------
         dict: A dictionary of the resolved values.
+
         """
         if len(json_files) == 0:
             raise ValueError("No JSON files provided")
@@ -461,7 +546,10 @@ class EEGBIDSDataset:
             json_dict.update(json.load(f))
         return json_dict
 
-    def get_bids_file_attribute(self, attribute, data_filepath):
+    def get_bids_file_attribute(self, attribute: str, data_filepath: str) -> Any:
+        """Retrieve a specific attribute from the BIDS file metadata applicable
+        to the provided recording file path.
+        """
         entities = self.layout.parse_file_entities(data_filepath)
         bidsfile = self.layout.get(**entities)[0]
         attributes = bidsfile.get_entities(metadata="all")
@@ -478,27 +566,32 @@ class EEGBIDSDataset:
         attribute_value = attributes.get(attribute_mapping.get(attribute), None)
         return attribute_value
 
-    def channel_labels(self, data_filepath):
+    def channel_labels(self, data_filepath: str) -> list[str]:
+        """Get a list of channel labels for the given data file path."""
         channels_tsv = pd.read_csv(
             self.get_bids_metadata_files(data_filepath, "channels.tsv")[0], sep="\t"
         )
         return channels_tsv["name"].tolist()
 
-    def channel_types(self, data_filepath):
+    def channel_types(self, data_filepath: str) -> list[str]:
+        """Get a list of channel types for the given data file path."""
         channels_tsv = pd.read_csv(
             self.get_bids_metadata_files(data_filepath, "channels.tsv")[0], sep="\t"
        )
         return channels_tsv["type"].tolist()
 
-    def num_times(self, data_filepath):
+    def num_times(self, data_filepath: str) -> int:
+        """Get the approximate number of time points in the EEG recording based on the BIDS metadata."""
         eeg_jsons = self.get_bids_metadata_files(data_filepath, "eeg.json")
         eeg_json_dict = self.merge_json_inheritance(eeg_jsons)
         return int(
             eeg_json_dict["SamplingFrequency"] * eeg_json_dict["RecordingDuration"]
         )
 
-    def subject_participant_tsv(self, data_filepath):
-        """Get
+    def subject_participant_tsv(self, data_filepath: str) -> dict[str, Any]:
+        """Get BIDS participants.tsv record for the subject to which the given file
+        path corresponds, as a dictionary.
+        """
         participants_tsv = pd.read_csv(
             self.get_bids_metadata_files(data_filepath, "participants.tsv")[0], sep="\t"
         )
@@ -510,12 +603,16 @@ class EEGBIDSDataset:
         subject = f"sub-{self.get_bids_file_attribute('subject', data_filepath)}"
         return participants_tsv.loc[subject].to_dict()
 
-    def eeg_json(self, data_filepath):
+    def eeg_json(self, data_filepath: str) -> dict[str, Any]:
+        """Get BIDS eeg.json metadata for the given data file path."""
         eeg_jsons = self.get_bids_metadata_files(data_filepath, "eeg.json")
         eeg_json_dict = self.merge_json_inheritance(eeg_jsons)
         return eeg_json_dict
 
-    def channel_tsv(self, data_filepath):
+    def channel_tsv(self, data_filepath: str) -> dict[str, Any]:
+        """Get BIDS channels.tsv metadata for the given data file path, as a dictionary
+        of lists and/or single values.
+        """
         channels_tsv = pd.read_csv(
             self.get_bids_metadata_files(data_filepath, "channels.tsv")[0], sep="\t"
         )
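Note how num_times above connects to the ntimes caveat added in data_config.py: both derive a sample count from SamplingFrequency * RecordingDuration and truncate, so the estimate can undershoot the true length (the source of the __len__ FIXME earlier). A quick illustration with made-up numbers:

# RecordingDuration (seconds) is often fractional; int() truncates.
sampling_frequency = 500.0
recording_duration = 12.9986  # as reported in eeg.json
print(int(sampling_frequency * recording_duration))  # 6499, not 6499.3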
eegdash/dataset.py
ADDED
@@ -0,0 +1,60 @@
+from .api import EEGDashDataset
+
+
+class EEGChallengeDataset(EEGDashDataset):
+    def __init__(
+        self,
+        release: str = "R5",
+        cache_dir: str = ".eegdash_cache",
+        s3_bucket: str | None = "s3://nmdatasets/NeurIPS25/R5_L100",
+        **kwargs,
+    ):
+        """Create a new EEGDashDataset from a given query or local BIDS dataset directory
+        and dataset name. An EEGDashDataset is a pooled collection of EEGDashBaseDataset
+        instances (individual recordings) and is a subclass of braindecode's BaseConcatDataset.
+
+        Parameters
+        ----------
+        query : dict | None
+            Optionally a dictionary that specifies the query to be executed; see
+            EEGDash.find() for details on the query format.
+        data_dir : str | list[str] | None
+            Optionally a string or a list of strings specifying one or more local
+            BIDS dataset directories from which to load the EEG data files. Exactly one
+            of query or data_dir must be provided.
+        dataset : str | list[str] | None
+            If data_dir is given, a name or list of names for the dataset(s) to be loaded.
+        description_fields : list[str]
+            A list of fields to be extracted from the dataset records
+            and included in the returned data description(s). Examples are typical
+            subject metadata fields such as "subject", "session", "run", "task", etc.;
+            see also data_config.description_fields for the default set of fields.
+        cache_dir : str
+            A directory where the dataset will be cached locally.
+        s3_bucket : str | None
+            An optional S3 bucket URI to use instead of the
+            default OpenNeuro bucket for loading data files.
+        kwargs : dict
+            Additional keyword arguments to be passed to the EEGDashBaseDataset
+            constructor.
+
+        """
+        dsnumber_release_map = {
+            "R11": "ds005516",
+            "R10": "ds005515",
+            "R9": "ds005514",
+            "R8": "ds005512",
+            "R7": "ds005511",
+            "R6": "ds005510",
+            "R4": "ds005508",
+            "R5": "ds005509",
+            "R3": "ds005507",
+            "R2": "ds005506",
+            "R1": "ds005505",
+        }
+        super().__init__(
+            query={"dataset": dsnumber_release_map[release]},
+            cache_dir=cache_dir,
+            s3_bucket=s3_bucket,
+            **kwargs,
+        )
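A usage sketch for the new class; the release names and their OpenNeuro dataset IDs come straight from the mapping above, and the datasets attribute comes from braindecode's BaseConcatDataset:

from eegdash.dataset import EEGChallengeDataset

# Release "R5" resolves to OpenNeuro dataset ds005509 and is served from the
# challenge bucket s3://nmdatasets/NeurIPS25/R5_L100 rather than openneuro.org.
ds = EEGChallengeDataset(release="R5", cache_dir=".eegdash_cache")
print(len(ds.datasets), "recordings")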