eegdash 0.3.3.dev61__py3-none-any.whl → 0.5.0.dev180784713__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- eegdash/__init__.py +19 -6
- eegdash/api.py +336 -539
- eegdash/bids_eeg_metadata.py +495 -0
- eegdash/const.py +349 -0
- eegdash/dataset/__init__.py +28 -0
- eegdash/dataset/base.py +311 -0
- eegdash/dataset/bids_dataset.py +641 -0
- eegdash/dataset/dataset.py +692 -0
- eegdash/dataset/dataset_summary.csv +255 -0
- eegdash/dataset/registry.py +287 -0
- eegdash/downloader.py +197 -0
- eegdash/features/__init__.py +15 -13
- eegdash/features/datasets.py +329 -138
- eegdash/features/decorators.py +105 -13
- eegdash/features/extractors.py +233 -63
- eegdash/features/feature_bank/__init__.py +12 -12
- eegdash/features/feature_bank/complexity.py +22 -20
- eegdash/features/feature_bank/connectivity.py +27 -28
- eegdash/features/feature_bank/csp.py +3 -1
- eegdash/features/feature_bank/dimensionality.py +6 -6
- eegdash/features/feature_bank/signal.py +29 -30
- eegdash/features/feature_bank/spectral.py +40 -44
- eegdash/features/feature_bank/utils.py +8 -0
- eegdash/features/inspect.py +126 -15
- eegdash/features/serialization.py +58 -17
- eegdash/features/utils.py +90 -16
- eegdash/hbn/__init__.py +28 -0
- eegdash/hbn/preprocessing.py +105 -0
- eegdash/hbn/windows.py +428 -0
- eegdash/logging.py +54 -0
- eegdash/mongodb.py +55 -24
- eegdash/paths.py +52 -0
- eegdash/utils.py +29 -1
- eegdash-0.5.0.dev180784713.dist-info/METADATA +121 -0
- eegdash-0.5.0.dev180784713.dist-info/RECORD +38 -0
- eegdash-0.5.0.dev180784713.dist-info/licenses/LICENSE +29 -0
- eegdash/data_config.py +0 -34
- eegdash/data_utils.py +0 -687
- eegdash/dataset.py +0 -69
- eegdash/preprocessing.py +0 -63
- eegdash-0.3.3.dev61.dist-info/METADATA +0 -192
- eegdash-0.3.3.dev61.dist-info/RECORD +0 -28
- eegdash-0.3.3.dev61.dist-info/licenses/LICENSE +0 -23
- {eegdash-0.3.3.dev61.dist-info → eegdash-0.5.0.dev180784713.dist-info}/WHEEL +0 -0
- {eegdash-0.3.3.dev61.dist-info → eegdash-0.5.0.dev180784713.dist-info}/top_level.txt +0 -0
eegdash/data_utils.py
DELETED
|
@@ -1,687 +0,0 @@
|
|
|
1
|
-
import json
|
|
2
|
-
import logging
|
|
3
|
-
import os
|
|
4
|
-
import re
|
|
5
|
-
from pathlib import Path
|
|
6
|
-
from typing import Any
|
|
7
|
-
|
|
8
|
-
import mne
|
|
9
|
-
import numpy as np
|
|
10
|
-
import pandas as pd
|
|
11
|
-
import s3fs
|
|
12
|
-
from bids import BIDSLayout
|
|
13
|
-
from fsspec.callbacks import TqdmCallback
|
|
14
|
-
from joblib import Parallel, delayed
|
|
15
|
-
from mne._fiff.utils import _read_segments_file
|
|
16
|
-
from mne.io import BaseRaw
|
|
17
|
-
from mne_bids import BIDSPath
|
|
18
|
-
|
|
19
|
-
from braindecode.datasets import BaseDataset
|
|
20
|
-
|
|
21
|
-
# Module-level logger shared by the classes in this module; it uses the fixed
# name "eegdash" rather than __name__.
logger = logging.getLogger("eegdash")
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
class EEGDashBaseDataset(BaseDataset):
    """A single EEG recording hosted on AWS S3 and cached locally upon first access.

    This is a subclass of braindecode's BaseDataset, which can consequently be used in
    conjunction with the preprocessing and training pipelines of braindecode.
    """

    _AWS_BUCKET = "s3://openneuro.org"
    # Region used for every anonymous S3 client created by this class.
    _S3_REGION = "us-east-2"

    def __init__(
        self,
        record: dict[str, Any],
        cache_dir: str,
        s3_bucket: str | None = None,
        **kwargs,
    ):
        """Create a new EEGDashBaseDataset instance. Users do not usually need to call this
        directly -- instead use the EEGDashDataset class to load a collection of these
        recordings from a local BIDS folder or using a database query.

        Parameters
        ----------
        record : dict
            A fully resolved metadata record for the data to load. The keys
            ``"bidspath"``, ``"dataset"`` and ``"bidsdependencies"`` are required.
            ``"subject"``, ``"session"``, ``"task"`` and ``"run"`` are used when
            present and truthy. ``"ntimes"`` and ``"sampling_frequency"`` are read
            by ``__len__`` before the data is loaded.
        cache_dir : str
            A local directory where the data will be cached.
        s3_bucket : str | None
            S3 bucket URI to download from. When falsy, the public OpenNeuro
            bucket (``_AWS_BUCKET``) is used.
        kwargs : dict
            Additional keyword arguments to pass to the BaseDataset constructor.

        """
        super().__init__(None, **kwargs)
        self.record = record
        self.cache_dir = Path(cache_dir)
        self.bids_kwargs = self.get_raw_bids_args()

        if s3_bucket:
            self.s3_bucket = s3_bucket
            self.s3_open_neuro = False
        else:
            self.s3_bucket = self._AWS_BUCKET
            self.s3_open_neuro = True

        self.filecache = self.cache_dir / record["bidspath"]

        self.bids_root = self.cache_dir / record["dataset"]

        self.bidspath = BIDSPath(
            root=self.bids_root,
            datatype="eeg",
            suffix="eeg",
            **self.bids_kwargs,
        )

        self.s3file = self.get_s3path(record["bidspath"])
        self.bids_dependencies = record["bidsdependencies"]
        # Temporary fix for BIDS dependencies path (released for the competition).
        # For non-OpenNeuro buckets, the S3 key is built without the first path
        # component. The original (prefixed) path is kept so that
        # _download_dependencies can still build the local cache location from it.
        if not self.s3_open_neuro:
            self.bids_dependencies_original = self.bids_dependencies
            # [-1] rather than [1]: a path without "/" is kept unchanged
            # instead of raising IndexError.
            self.bids_dependencies = [
                dep.split("/", 1)[-1] for dep in self.bids_dependencies
            ]

        self._raw = None

    def get_s3path(self, filepath: str) -> str:
        """Helper to form an AWS S3 URI for the given relative filepath."""
        return f"{self.s3_bucket}/{filepath}"

    @classmethod
    def _s3_filesystem(cls) -> "s3fs.S3FileSystem":
        """Return an anonymous S3 filesystem client for ``_S3_REGION``."""
        return s3fs.S3FileSystem(
            anon=True, client_kwargs={"region_name": cls._S3_REGION}
        )

    @staticmethod
    def _download_with_progress(filesystem, s3path: str, local_path) -> None:
        """Download ``s3path`` to ``local_path`` while showing a tqdm progress bar.

        The parent directory of ``local_path`` must already exist.
        """
        info = filesystem.info(s3path)
        # The object size may be reported under either key casing.
        size = info.get("size") or info.get("Size")

        callback = TqdmCallback(
            size=size,
            tqdm_kwargs=dict(
                desc=f"Downloading {Path(s3path).name}",
                unit="B",
                unit_scale=True,
                unit_divisor=1024,
                dynamic_ncols=True,
                leave=True,
                mininterval=0.2,
                smoothing=0.1,
                miniters=1,
                bar_format="{desc}: {percentage:3.0f}%|{bar}| {n_fmt}/{total_fmt} "
                "[{elapsed}<{remaining}, {rate_fmt}]",
            ),
        )
        filesystem.get(s3path, local_path, callback=callback)

    def _download_s3(self) -> None:
        """Download function that gets the raw EEG data from S3.

        Sets ``self.filenames`` to the cached file afterwards.
        """
        filesystem = self._s3_filesystem()
        if not self.s3_open_neuro:
            # Non-OpenNeuro buckets: drop the first "dsXXXXXX/" component from
            # the key. Note that this rewrites self.s3file in place.
            self.s3file = re.sub(r"(^|/)ds\d{6}/", r"\1", self.s3file, count=1)

        self.filecache.parent.mkdir(parents=True, exist_ok=True)
        self._download_with_progress(filesystem, self.s3file, self.filecache)

        self.filenames = [self.filecache]

    def _download_dependencies(self) -> None:
        """Download all BIDS dependency files (metadata files, recording sidecar files)
        from S3 and cache them locally. Files already present in the cache are skipped.
        """
        filesystem = self._s3_filesystem()
        for i, dep in enumerate(self.bids_dependencies):
            s3path = self.get_s3path(dep)
            if not self.s3_open_neuro:
                # The local cache keeps the original, prefixed layout.
                dep = self.bids_dependencies_original[i]

            filepath = self.cache_dir / dep
            # here, we download the dependency and it is fine
            # in the case of the competition.
            if not filepath.exists():
                filepath.parent.mkdir(parents=True, exist_ok=True)
                self._download_with_progress(filesystem, s3path, filepath)

    def get_raw_bids_args(self) -> dict[str, Any]:
        """Helper to restrict the metadata record to the fields needed to locate a BIDS
        recording. Missing or falsy fields are omitted.
        """
        desired_fields = ["subject", "session", "task", "run"]
        return {k: self.record[k] for k in desired_fields if self.record.get(k)}

    def check_and_get_raw(self) -> None:
        """Download the S3 file and BIDS dependencies if not already cached, then
        read the recording with MNE if it has not been loaded yet.
        """
        if not os.path.exists(self.filecache):  # not preload
            if self.bids_dependencies:
                self._download_dependencies()
            self._download_s3()
        if self._raw is None:
            self._raw = mne.io.read_raw(fname=self.bidspath, verbose=False)

    # === BaseDataset and PyTorch Dataset interface ===

    def __getitem__(self, index):
        """Main function to access a sample from the dataset.

        Returns ``(X, y)`` where ``X`` is the data at time index ``index`` for all
        channels (after ``self.transform`` if set). ``y`` is the description value
        for ``self.target_name`` (converted to a list if it is a Series), or None.
        """
        X = self.raw[:, index][0]
        y = None
        if self.target_name is not None:
            y = self.description[self.target_name]
        if isinstance(y, pd.Series):
            y = y.to_list()
        if self.transform is not None:
            X = self.transform(X)
        return X, y

    def __len__(self) -> int:
        """Return the number of samples in the dataset."""
        if self._raw is None:
            # FIXME: this is a bit strange and should definitely not change as a side effect
            # of accessing the data (which it will, since ntimes is the actual length but rounded down)
            return int(self.record["ntimes"] * self.record["sampling_frequency"])
        else:
            return len(self._raw)

    @property
    def raw(self):
        """Return the MNE Raw object for this recording. This will perform the actual
        retrieval if not yet done so.
        """
        if self._raw is None:
            self.check_and_get_raw()
        return self._raw

    @raw.setter
    def raw(self, raw):
        self._raw = raw
|
|
214
|
-
|
|
215
|
-
|
|
216
|
-
class EEGDashBaseRaw(BaseRaw):
    """Wrapper around the MNE BaseRaw class that automatically fetches the data from S3
    (when _read_segment is called) and caches it locally. Currently for internal use.

    Parameters
    ----------
    input_fname : path-like
        Path to the S3 file
    metadata : dict
        The metadata record for the recording (e.g., from the database). The keys
        ``"sfreq"``, ``"n_times"``, ``"ch_names"`` and ``"ch_types"`` are read.
    preload : bool
        Whether to preload the data before the first access.
    cache_dir : str
        Local path under which the data will be cached. A leading ``~`` is
        expanded to the user's home directory.
    bids_dependencies : list | None
        List of additional BIDS metadata files that should be downloaded and cached
        alongside the main recording file.
    verbose : str | int | None
        Optionally the verbosity level for MNE logging (see MNE documentation for possible values).

    See Also
    --------
    mne.io.Raw : Documentation of attributes and methods.

    """

    _AWS_BUCKET = "s3://openneuro.org"
    # Region used for every anonymous S3 client created by this class.
    _S3_REGION = "us-east-2"

    def __init__(
        self,
        input_fname: str,
        metadata: dict[str, Any],
        preload: bool = False,
        *,
        cache_dir: str = "~/eegdash_cache",
        bids_dependencies: list[str] | None = None,
        verbose: Any = None,
    ):
        """Build the Raw header from ``metadata`` and optionally fetch the data now."""
        sfreq = metadata["sfreq"]  # Sampling frequency
        n_times = metadata["n_times"]
        ch_names = metadata["ch_names"]
        ch_types = []
        for ch in metadata["ch_types"]:
            chtype = ch.lower()
            # HEOG/VEOG labels are mapped to the generic "eog" channel type.
            if chtype in ("heog", "veog"):
                chtype = "eog"
            ch_types.append(chtype)
        info = mne.create_info(ch_names=ch_names, sfreq=sfreq, ch_types=ch_types)
        self.s3file = self.get_s3path(input_fname)
        # expanduser(): without it, the default "~/eegdash_cache" would create a
        # literal "~" directory in the current working directory.
        self.cache_dir = Path(cache_dir).expanduser()
        self.filecache = self.cache_dir / input_fname
        # Copy into a fresh list: the former `= []` default was shared across calls.
        self.bids_dependencies = (
            list(bids_dependencies) if bids_dependencies is not None else []
        )

        if preload:
            self._ensure_cached()

        # `preload` is forwarded unchanged. Previously the cache path was passed
        # here, but MNE treats a str `preload` as the name of a memory-mapped
        # output file, which would overwrite the just-downloaded recording.
        super().__init__(
            info,
            preload,
            last_samps=[n_times - 1],
            orig_format="single",
            verbose=verbose,
        )

    def get_s3path(self, filepath):
        """Return the S3 URI of ``filepath`` inside the OpenNeuro bucket."""
        logger.debug("Getting S3 path for %s", filepath)
        return f"{self._AWS_BUCKET}/{filepath}"

    @classmethod
    def _s3_filesystem(cls) -> "s3fs.S3FileSystem":
        """Return an anonymous S3 filesystem client for ``_S3_REGION``."""
        return s3fs.S3FileSystem(
            anon=True, client_kwargs={"region_name": cls._S3_REGION}
        )

    def _download_s3(self) -> None:
        """Download the recording to the local cache and point ``filenames`` at it."""
        self.filecache.parent.mkdir(parents=True, exist_ok=True)
        filesystem = self._s3_filesystem()
        filesystem.download(self.s3file, self.filecache)
        self.filenames = [self.filecache]

    def _download_dependencies(self):
        """Download BIDS dependency files that are not cached yet."""
        filesystem = self._s3_filesystem()
        for dep in self.bids_dependencies:
            s3path = self.get_s3path(dep)
            filepath = self.cache_dir / dep
            if not filepath.exists():
                filepath.parent.mkdir(parents=True, exist_ok=True)
                filesystem.download(s3path, filepath)

    def _ensure_cached(self) -> None:
        """Fetch dependencies and the recording when the recording is not cached.

        Shared by the preload path (``__init__``) and the lazy path
        (``_read_segment``) so that both download the same files.
        """
        if not os.path.exists(self.filecache):
            if self.bids_dependencies:
                self._download_dependencies()
            self._download_s3()

    def _read_segment(
        self, start=0, stop=None, sel=None, data_buffer=None, *, verbose=None
    ):
        """Make sure the recording is cached locally, then delegate to MNE."""
        if not os.path.exists(self.filecache):
            # Not cached yet: fetch it (_download_s3 also sets self.filenames).
            self._ensure_cached()
        else:
            # Already cached: read from the local copy.
            self.filenames = [self.filecache]
        return super()._read_segment(start, stop, sel, data_buffer, verbose=verbose)

    def _read_segment_file(self, data, idx, fi, start, stop, cals, mult):
        """Read a chunk of data from the file (little-endian float32 samples)."""
        _read_segments_file(self, data, idx, fi, start, stop, cals, mult, dtype="<f4")
|
|
320
|
-
|
|
321
|
-
|
|
322
|
-
class EEGBIDSDataset:
    """A one-stop shop interface to a local BIDS dataset containing EEG recordings.

    This is mainly tailored to the needs of EEGDash application and is used to centralize
    interactions with the BIDS dataset, such as parsing the metadata.

    Parameters
    ----------
    data_dir : str | Path
        The path to the local BIDS dataset directory.
    dataset : str
        A name for the dataset. ``data_dir`` must end with this name.

    Raises
    ------
    ValueError
        If ``data_dir`` is missing or does not exist, does not end with ``dataset``,
        contains no recording with a supported extension, or the first recording
        found is not of the EEG datatype.

    """

    ALLOWED_FILE_FORMAT = ["eeglab", "brainvision", "biosemi", "european"]
    RAW_EXTENSIONS = {
        ".set": [".set", ".fdt"],  # eeglab
        ".edf": [".edf"],  # european
        ".vhdr": [".eeg", ".vhdr", ".vmrk", ".dat", ".raw"],  # brainvision
        ".bdf": [".bdf"],  # biosemi
    }
    METADATA_FILE_EXTENSIONS = [
        "eeg.json",
        "channels.tsv",
        "electrodes.tsv",
        "events.tsv",
        "events.json",
    ]

    def __init__(
        self,
        data_dir=None,  # location of bids dataset
        dataset="",  # dataset name
    ):
        if data_dir is None or not os.path.exists(data_dir):
            raise ValueError("data_dir must be specified and must exist")
        self.bidsdir = Path(data_dir)
        self.dataset = dataset
        # Explicit raises rather than `assert`: asserts are stripped under -O, and
        # `assert cond, ValueError(...)` raised AssertionError, not ValueError.
        if not str(self.bidsdir).endswith(self.dataset):
            raise ValueError(
                f"data_dir {self.bidsdir} does not end with dataset name {self.dataset!r}"
            )
        self.layout = BIDSLayout(data_dir)

        # get all recording files in the bids directory
        self.files = self.get_recordings(self.layout)
        if len(self.files) == 0:
            raise ValueError(
                "Unable to construct EEG dataset. No EEG recordings found."
            )
        if not self.check_eeg_dataset():
            raise ValueError("Dataset is not an EEG dataset.")

    def check_eeg_dataset(self) -> bool:
        """Check if the dataset is EEG, judging by the datatype of the first recording."""
        modality = self.get_bids_file_attribute("modality", self.files[0])
        # get_bids_file_attribute returns None when the datatype entity is absent.
        return modality is not None and modality.lower() == "eeg"

    def get_recordings(self, layout: BIDSLayout) -> list[str]:
        """Get a list of all EEG recording files in the BIDS layout.

        Only files of the first extension in ``RAW_EXTENSIONS`` (in dict order) that
        yields any file are returned; an empty list if none matches.
        """
        files = []
        for ext in self.RAW_EXTENSIONS:
            files = layout.get(extension=ext, return_type="filename")
            if files:
                break
        return files

    def get_relative_bidspath(self, filename: str) -> str:
        """Make the given file path relative to the parent of the BIDS directory."""
        bids_parent_dir = self.bidsdir.parent.absolute()
        return str(Path(filename).relative_to(bids_parent_dir))

    def get_property_from_filename(self, property: str, filename: str) -> str:
        """Parse a property out of a BIDS-compliant filename. Returns an empty string
        if not found.
        """
        import platform

        if platform.system() == "Windows":
            lookup = re.search(rf"{property}-(.*?)[_\\]", filename)
        else:
            lookup = re.search(rf"{property}-(.*?)[_\/]", filename)
        return lookup.group(1) if lookup else ""

    @staticmethod
    def _merge_json_files(json_files) -> dict:
        """Merge JSON files given lowest level first; lower levels override higher ones.

        The input sequence is not modified, and every file is closed after reading.
        """
        json_dict = {}
        # Apply the highest level first so more specific files take precedence.
        for json_file in reversed(json_files):
            with open(json_file) as f:
                json_dict.update(json.load(f))
        return json_dict

    def merge_json_inheritance(self, json_files: list[str | Path]) -> dict:
        """Internal helper to merge list of json files found by get_bids_file_inheritance,
        expecting the order (from left to right) is from lowest
        level to highest level, and return a merged dictionary.
        An empty list yields an empty dictionary.
        """
        return self._merge_json_files(json_files)

    def get_bids_file_inheritance(
        self, path: str | Path, basename: str, extension: str
    ) -> list[Path]:
        """Get all file paths that apply to the basename file in the specified directory
        and that end with the specified suffix, recursively searching parent directories
        (following the BIDS inheritance principle in the order of lowest level first).

        Parameters
        ----------
        path : str | Path
            The directory path to search for files.
        basename : str
            BIDS file basename without _eeg.set extension for example
        extension : str
            Only consider files that end with the specified suffix; e.g. channels.tsv

        Returns
        -------
        list[Path]
            A list of file paths that match the given extension. Within one
            directory the files are sorted by name so the result is deterministic.

        Raises
        ------
        ValueError
            If ``path`` does not exist.

        """
        top_level_files = ["README", "dataset_description.json", "participants.tsv"]
        path = Path(path)
        if not path.exists():
            raise ValueError(f"path {path} does not exist")

        # sorted(): os.listdir order is arbitrary, which made merges and the
        # `[0]` picks done by callers nondeterministic.
        entries = sorted(os.listdir(path))
        # NOTE(review): `basename` is not used for matching -- every file ending with
        # `extension` is collected (including ones belonging to other tasks/runs in
        # the same directory). Kept as-is to preserve behavior; confirm intent.
        bids_files = [
            path / name
            for name in entries
            if name.endswith(extension) and os.path.isfile(path / name)
        ]

        # A top-level BIDS file marks the dataset root, which ends the upward search.
        # Also stop at the filesystem root, where path.parent == path, to avoid
        # infinite recursion.
        if any(name in entries for name in top_level_files) or path.parent == path:
            return bids_files
        bids_files.extend(
            self.get_bids_file_inheritance(path.parent, basename, extension)
        )
        return bids_files

    def get_bids_metadata_files(
        self, filepath: str | Path, metadata_file_extension: list[str]
    ) -> list[Path]:
        """Retrieve all metadata file paths that apply to a given data file path and that
        end with a specific suffix (following the BIDS inheritance principle).

        Parameters
        ----------
        filepath: str | Path
            The filepath to get the associated metadata files for.
        metadata_file_extension : str
            Consider only metadata files that end with the specified suffix,
            e.g., channels.tsv or eeg.json

        Returns
        -------
        list[Path]:
            A list of filepaths for all matching metadata files

        Raises
        ------
        ValueError
            If ``filepath`` does not exist.

        """
        filepath = Path(filepath)
        # lexists() also accepts dangling symlinks -- presumably un-fetched
        # datalad/git-annex files -- so metadata can be read without the data.
        if not os.path.lexists(filepath):
            raise ValueError(f"filepath {filepath} does not exist")
        path, filename = os.path.split(filepath)
        basename = filename[: filename.rfind("_")]
        # metadata files
        return self.get_bids_file_inheritance(path, basename, metadata_file_extension)

    def scan_directory(self, directory: str, extension: str) -> list[Path]:
        """Return a list of file paths that end with the given extension in the specified
        directory, plus the subdirectories still to be scanned. Subdirectories whose
        name contains .git, .datalad, derivatives, or code are ignored.
        """
        result_files = []
        directory_to_ignore = [".git", ".datalad", "derivatives", "code"]
        with os.scandir(directory) as entries:
            for entry in entries:
                if entry.is_file() and entry.name.endswith(extension):
                    # debug level: logging every matched file is too noisy for info/stdout.
                    logger.debug("Adding %s", entry.path)
                    result_files.append(entry.path)
                elif entry.is_dir():
                    # check that entry name doesn't contain any name in ignore list
                    if not any(name in entry.name for name in directory_to_ignore):
                        result_files.append(entry.path)  # Add directory to scan later
        return result_files

    def get_files_with_extension_parallel(
        self, directory: str, extension: str = ".set", max_workers: int = -1
    ) -> list[Path]:
        """Efficiently scan a directory and its subdirectories for files that end with
        the given extension.

        Parameters
        ----------
        directory : str
            The root directory to scan for files.
        extension : str
            Only consider files that end with this suffix, e.g. '.set'.
        max_workers : int
            Optionally specify the maximum number of worker threads to use for parallel scanning.
            Defaults to all available CPU cores if set to -1.

        Returns
        -------
        list[Path]:
            A list of filepaths for all matching files

        """
        result_files = []
        dirs_to_scan = [directory]

        # Breadth-first: each level of directories is scanned in parallel threads.
        while dirs_to_scan:
            logger.info(
                f"Directories to scan: {len(dirs_to_scan)}, files: {dirs_to_scan}"
            )
            # Run the scan_directory function in parallel across directories
            results = Parallel(n_jobs=max_workers, prefer="threads", verbose=1)(
                delayed(self.scan_directory)(d, extension) for d in dirs_to_scan
            )

            # Reset the directories to scan and process the results
            dirs_to_scan = []
            for res in results:
                for path in res:
                    if os.path.isdir(path):
                        dirs_to_scan.append(path)  # Queue up subdirectories to scan
                    else:
                        result_files.append(path)  # Add files to the final result
            logger.info(f"Found {len(result_files)} files.")

        return result_files

    def load_and_preprocess_raw(
        self, raw_file: str, preprocess: bool = False
    ) -> np.ndarray:
        """Utility function to load a raw data file with MNE and apply some simple
        (hardcoded) preprocessing and return as a numpy array. Not meant for purposes
        other than testing or debugging.

        Raises
        ------
        ValueError
            If the loaded data has more than two dimensions.
        """
        logger.info(f"Loading raw data from {raw_file}")
        EEG = mne.io.read_raw_eeglab(raw_file, preload=True, verbose="error")

        if preprocess:
            # band-pass filter between 0.25 Hz and 25 Hz
            EEG = EEG.filter(l_freq=0.25, h_freq=25, verbose=False)
            # remove 60Hz line noise
            EEG = EEG.notch_filter(freqs=60, verbose=False)
            # bring to common sampling rate
            sfreq = 128
            if EEG.info["sfreq"] != sfreq:
                EEG = EEG.resample(sfreq)

        mat_data = EEG.get_data()

        if len(mat_data.shape) > 2:
            raise ValueError("Expect raw data to be CxT dimension")
        return mat_data

    def get_files(self) -> list[Path]:
        """Get all EEG recording file paths (with valid extensions) in the BIDS folder."""
        return self.files

    def resolve_bids_json(self, json_files: list[str]) -> dict:
        """Resolve the BIDS JSON files and return a dictionary of the resolved values.

        Parameters
        ----------
        json_files : list
            A list of JSON file paths to resolve in order of leaf level first.
            The list itself is not modified.

        Returns
        -------
        dict: A dictionary of the resolved values.

        Raises
        ------
        ValueError
            If ``json_files`` is empty.

        """
        if len(json_files) == 0:
            raise ValueError("No JSON files provided")
        return self._merge_json_files(json_files)

    def get_bids_file_attribute(self, attribute: str, data_filepath: str) -> Any:
        """Retrieve a specific attribute from the BIDS file metadata applicable
        to the provided recording file path. Returns None for unknown attributes
        or attributes missing from the metadata.
        """
        entities = self.layout.parse_file_entities(data_filepath)
        bidsfile = self.layout.get(**entities)[0]
        attributes = bidsfile.get_entities(metadata="all")
        attribute_mapping = {
            "sfreq": "SamplingFrequency",
            "modality": "datatype",
            "task": "task",
            "session": "session",
            "run": "run",
            "subject": "subject",
            "ntimes": "RecordingDuration",
            "nchans": "EEGChannelCount",
        }
        attribute_value = attributes.get(attribute_mapping.get(attribute), None)
        return attribute_value

    def _read_channels_tsv(self, data_filepath: str) -> pd.DataFrame:
        """Read the first channels.tsv that applies to ``data_filepath``."""
        return pd.read_csv(
            self.get_bids_metadata_files(data_filepath, "channels.tsv")[0], sep="\t"
        )

    def channel_labels(self, data_filepath: str) -> list[str]:
        """Get a list of channel labels for the given data file path."""
        return self._read_channels_tsv(data_filepath)["name"].tolist()

    def channel_types(self, data_filepath: str) -> list[str]:
        """Get a list of channel types for the given data file path."""
        return self._read_channels_tsv(data_filepath)["type"].tolist()

    def num_times(self, data_filepath: str) -> int:
        """Get the approximate number of time points in the EEG recording based on the BIDS metadata."""
        eeg_jsons = self.get_bids_metadata_files(data_filepath, "eeg.json")
        eeg_json_dict = self.merge_json_inheritance(eeg_jsons)
        return int(
            eeg_json_dict["SamplingFrequency"] * eeg_json_dict["RecordingDuration"]
        )

    def subject_participant_tsv(self, data_filepath: str) -> dict[str, Any]:
        """Get BIDS participants.tsv record for the subject to which the given file
        path corresponds, as a dictionary. Returns an empty dict if the table is empty.
        """
        participants_tsv = pd.read_csv(
            self.get_bids_metadata_files(data_filepath, "participants.tsv")[0], sep="\t"
        )
        # if participants_tsv is not empty
        if participants_tsv.empty:
            return {}
        # set 'participant_id' as index
        participants_tsv.set_index("participant_id", inplace=True)
        subject = f"sub-{self.get_bids_file_attribute('subject', data_filepath)}"
        return participants_tsv.loc[subject].to_dict()

    def eeg_json(self, data_filepath: str) -> dict[str, Any]:
        """Get BIDS eeg.json metadata for the given data file path."""
        eeg_jsons = self.get_bids_metadata_files(data_filepath, "eeg.json")
        eeg_json_dict = self.merge_json_inheritance(eeg_jsons)
        return eeg_json_dict

    def channel_tsv(self, data_filepath: str) -> dict[str, Any]:
        """Get BIDS channels.tsv metadata for the given data file path, as a dictionary
        of lists and/or single values.
        """
        channel_tsv = self._read_channels_tsv(data_filepath).to_dict()
        # 'name', 'type' and 'units' hold an index->value dictionary; convert them to
        # lists. Columns absent from the file are skipped instead of raising KeyError.
        for list_field in ["name", "type", "units"]:
            if list_field in channel_tsv:
                channel_tsv[list_field] = list(channel_tsv[list_field].values())
        return channel_tsv
|