eegdash 0.0.9__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of eegdash might be problematic. Click here for more details.
- eegdash/__init__.py +8 -1
- eegdash/api.py +690 -0
- eegdash/data_config.py +33 -27
- eegdash/data_utils.py +365 -222
- eegdash/dataset.py +60 -0
- eegdash/features/__init__.py +46 -18
- eegdash/features/datasets.py +62 -23
- eegdash/features/decorators.py +14 -6
- eegdash/features/extractors.py +22 -22
- eegdash/features/feature_bank/__init__.py +3 -3
- eegdash/features/feature_bank/complexity.py +6 -3
- eegdash/features/feature_bank/connectivity.py +16 -56
- eegdash/features/feature_bank/csp.py +3 -4
- eegdash/features/feature_bank/dimensionality.py +8 -5
- eegdash/features/feature_bank/signal.py +30 -4
- eegdash/features/feature_bank/spectral.py +10 -28
- eegdash/features/feature_bank/utils.py +48 -0
- eegdash/features/inspect.py +48 -0
- eegdash/features/serialization.py +4 -5
- eegdash/features/utils.py +9 -7
- eegdash/preprocessing.py +65 -0
- eegdash/utils.py +11 -0
- {eegdash-0.0.9.dist-info → eegdash-0.2.0.dist-info}/METADATA +67 -20
- eegdash-0.2.0.dist-info/RECORD +27 -0
- {eegdash-0.0.9.dist-info → eegdash-0.2.0.dist-info}/WHEEL +1 -1
- {eegdash-0.0.9.dist-info → eegdash-0.2.0.dist-info}/licenses/LICENSE +1 -0
- eegdash/main.py +0 -359
- eegdash-0.0.9.dist-info/RECORD +0 -22
- {eegdash-0.0.9.dist-info → eegdash-0.2.0.dist-info}/top_level.txt +0 -0
eegdash/data_utils.py
CHANGED
|
@@ -1,90 +1,122 @@
|
|
|
1
|
+
import json
|
|
2
|
+
import logging
|
|
1
3
|
import os
|
|
2
|
-
import
|
|
3
|
-
from
|
|
4
|
+
import re
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
from typing import Any
|
|
7
|
+
|
|
4
8
|
import mne
|
|
9
|
+
import mne_bids
|
|
5
10
|
import numpy as np
|
|
6
11
|
import pandas as pd
|
|
7
|
-
from pathlib import Path
|
|
8
|
-
import re
|
|
9
|
-
import json
|
|
10
|
-
from mne.io import BaseRaw
|
|
11
|
-
from mne._fiff.utils import _find_channels, _read_segments_file
|
|
12
12
|
import s3fs
|
|
13
|
-
import
|
|
13
|
+
from bids import BIDSLayout
|
|
14
|
+
from joblib import Parallel, delayed
|
|
14
15
|
from mne._fiff.utils import _read_segments_file
|
|
15
|
-
from
|
|
16
|
-
import mne_bids
|
|
16
|
+
from mne.io import BaseRaw
|
|
17
17
|
from mne_bids import (
|
|
18
18
|
BIDSPath,
|
|
19
19
|
)
|
|
20
|
-
from bids import BIDSLayout
|
|
21
20
|
|
|
22
|
-
|
|
23
|
-
"""Returns samples from an mne.io.Raw object along with a target.
|
|
21
|
+
from braindecode.datasets import BaseDataset
|
|
24
22
|
|
|
25
|
-
|
|
26
|
-
The target is unique for the dataset, and is obtained through the
|
|
27
|
-
`description` attribute.
|
|
23
|
+
logger = logging.getLogger("eegdash")
|
|
28
24
|
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
target_name : str | tuple | None
|
|
36
|
-
Name(s) of the index in `description` that should be used to provide the
|
|
37
|
-
target (e.g., to be used in a prediction task later on).
|
|
38
|
-
transform : callable | None
|
|
39
|
-
On-the-fly transform applied to the example before it is returned.
|
|
25
|
+
|
|
26
|
+
class EEGDashBaseDataset(BaseDataset):
|
|
27
|
+
"""A single EEG recording hosted on AWS S3 and cached locally upon first access.
|
|
28
|
+
|
|
29
|
+
This is a subclass of braindecode's BaseDataset, which can consequently be used in
|
|
30
|
+
conjunction with the preprocessing and training pipelines of braindecode.
|
|
40
31
|
"""
|
|
41
|
-
|
|
42
|
-
|
|
32
|
+
|
|
33
|
+
AWS_BUCKET = "s3://openneuro.org"
|
|
34
|
+
|
|
35
|
+
def __init__(
|
|
36
|
+
self,
|
|
37
|
+
record: dict[str, Any],
|
|
38
|
+
cache_dir: str,
|
|
39
|
+
s3_bucket: str | None = None,
|
|
40
|
+
**kwargs,
|
|
41
|
+
):
|
|
42
|
+
"""Create a new EEGDashBaseDataset instance. Users do not usually need to call this
|
|
43
|
+
directly -- instead use the EEGDashDataset class to load a collection of these
|
|
44
|
+
recordings from a local BIDS folder or using a database query.
|
|
45
|
+
|
|
46
|
+
Parameters
|
|
47
|
+
----------
|
|
48
|
+
record : dict
|
|
49
|
+
A fully resolved metadata record for the data to load.
|
|
50
|
+
cache_dir : str
|
|
51
|
+
A local directory where the data will be cached.
|
|
52
|
+
kwargs : dict
|
|
53
|
+
Additional keyword arguments to pass to the BaseDataset constructor.
|
|
54
|
+
|
|
55
|
+
"""
|
|
43
56
|
super().__init__(None, **kwargs)
|
|
44
57
|
self.record = record
|
|
45
58
|
self.cache_dir = Path(cache_dir)
|
|
46
59
|
bids_kwargs = self.get_raw_bids_args()
|
|
47
|
-
|
|
48
|
-
self.
|
|
49
|
-
|
|
50
|
-
|
|
60
|
+
|
|
61
|
+
self.bidspath = BIDSPath(
|
|
62
|
+
root=self.cache_dir / record["dataset"],
|
|
63
|
+
datatype="eeg",
|
|
64
|
+
suffix="eeg",
|
|
65
|
+
**bids_kwargs,
|
|
66
|
+
)
|
|
67
|
+
self.s3_bucket = s3_bucket if s3_bucket else self.AWS_BUCKET
|
|
68
|
+
self.s3file = self.get_s3path(record["bidspath"])
|
|
69
|
+
self.filecache = self.cache_dir / record["bidspath"]
|
|
70
|
+
self.bids_dependencies = record["bidsdependencies"]
|
|
51
71
|
self._raw = None
|
|
52
|
-
# if os.path.exists(self.filecache):
|
|
53
|
-
# self.raw = mne_bids.read_raw_bids(self.bidspath, verbose=False)
|
|
54
72
|
|
|
55
|
-
def get_s3path(self, filepath):
|
|
56
|
-
|
|
73
|
+
def get_s3path(self, filepath: str) -> str:
|
|
74
|
+
"""Helper to form an AWS S3 URI for the given relative filepath."""
|
|
75
|
+
return f"{self.s3_bucket}/{filepath}"
|
|
57
76
|
|
|
58
|
-
def _download_s3(self):
|
|
77
|
+
def _download_s3(self) -> None:
|
|
78
|
+
"""Fetch the given data from its S3 location and cache it locally."""
|
|
59
79
|
self.filecache.parent.mkdir(parents=True, exist_ok=True)
|
|
60
|
-
filesystem = s3fs.S3FileSystem(
|
|
80
|
+
filesystem = s3fs.S3FileSystem(
|
|
81
|
+
anon=True, client_kwargs={"region_name": "us-east-2"}
|
|
82
|
+
)
|
|
61
83
|
filesystem.download(self.s3file, self.filecache)
|
|
62
84
|
self.filenames = [self.filecache]
|
|
63
85
|
|
|
64
|
-
def _download_dependencies(self):
|
|
65
|
-
|
|
86
|
+
def _download_dependencies(self) -> None:
|
|
87
|
+
"""Download all BIDS dependency files (metadata files, recording sidecar files)
|
|
88
|
+
from S3 and cache them locally.
|
|
89
|
+
"""
|
|
90
|
+
filesystem = s3fs.S3FileSystem(
|
|
91
|
+
anon=True, client_kwargs={"region_name": "us-east-2"}
|
|
92
|
+
)
|
|
66
93
|
for dep in self.bids_dependencies:
|
|
67
94
|
s3path = self.get_s3path(dep)
|
|
68
95
|
filepath = self.cache_dir / dep
|
|
69
96
|
if not filepath.exists():
|
|
70
97
|
filepath.parent.mkdir(parents=True, exist_ok=True)
|
|
71
|
-
filesystem.download(s3path, filepath)
|
|
98
|
+
filesystem.download(s3path, filepath)
|
|
72
99
|
|
|
73
|
-
def get_raw_bids_args(self):
|
|
74
|
-
|
|
100
|
+
def get_raw_bids_args(self) -> dict[str, Any]:
|
|
101
|
+
"""Helper to restrict the metadata record to the fields needed to locate a BIDS
|
|
102
|
+
recording.
|
|
103
|
+
"""
|
|
104
|
+
desired_fields = ["subject", "session", "task", "run"]
|
|
75
105
|
return {k: self.record[k] for k in desired_fields if self.record[k]}
|
|
76
106
|
|
|
77
|
-
def check_and_get_raw(self):
|
|
78
|
-
|
|
107
|
+
def check_and_get_raw(self) -> None:
|
|
108
|
+
"""Download the S3 file and BIDS dependencies if not already cached."""
|
|
109
|
+
if not os.path.exists(self.filecache): # not preload
|
|
79
110
|
if self.bids_dependencies:
|
|
80
111
|
self._download_dependencies()
|
|
81
112
|
self._download_s3()
|
|
82
113
|
if self._raw is None:
|
|
83
114
|
self._raw = mne_bids.read_raw_bids(self.bidspath, verbose=False)
|
|
84
115
|
|
|
85
|
-
|
|
86
|
-
# self.check_and_get_raw()
|
|
116
|
+
# === BaseDataset and PyTorch Dataset interface ===
|
|
87
117
|
|
|
118
|
+
def __getitem__(self, index):
|
|
119
|
+
"""Main function to access a sample from the dataset."""
|
|
88
120
|
X = self.raw[:, index][0]
|
|
89
121
|
y = None
|
|
90
122
|
if self.target_name is not None:
|
|
@@ -94,15 +126,21 @@ class EEGDashBaseDataset(BaseDataset):
|
|
|
94
126
|
if self.transform is not None:
|
|
95
127
|
X = self.transform(X)
|
|
96
128
|
return X, y
|
|
97
|
-
|
|
98
|
-
def __len__(self):
|
|
129
|
+
|
|
130
|
+
def __len__(self) -> int:
|
|
131
|
+
"""Return the number of samples in the dataset."""
|
|
99
132
|
if self._raw is None:
|
|
100
|
-
|
|
133
|
+
# FIXME: this is a bit strange and should definitely not change as a side effect
|
|
134
|
+
# of accessing the data (which it will, since ntimes is the actual length but rounded down)
|
|
135
|
+
return int(self.record["ntimes"] * self.record["sampling_frequency"])
|
|
101
136
|
else:
|
|
102
137
|
return len(self._raw)
|
|
103
138
|
|
|
104
139
|
@property
|
|
105
140
|
def raw(self):
|
|
141
|
+
"""Return the MNE Raw object for this recording. This will perform the actual
|
|
142
|
+
retrieval if not yet done so.
|
|
143
|
+
"""
|
|
106
144
|
if self._raw is None:
|
|
107
145
|
self.check_and_get_raw()
|
|
108
146
|
return self._raw
|
|
@@ -111,59 +149,55 @@ class EEGDashBaseDataset(BaseDataset):
|
|
|
111
149
|
def raw(self, raw):
|
|
112
150
|
self._raw = raw
|
|
113
151
|
|
|
152
|
+
|
|
114
153
|
class EEGDashBaseRaw(BaseRaw):
|
|
115
|
-
|
|
154
|
+
"""Wrapper around the MNE BaseRaw class that automatically fetches the data from S3
|
|
155
|
+
(when _read_segment is called) and caches it locally. Currently for internal use.
|
|
116
156
|
|
|
117
157
|
Parameters
|
|
118
158
|
----------
|
|
119
159
|
input_fname : path-like
|
|
120
160
|
Path to the S3 file
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
161
|
+
metadata : dict
|
|
162
|
+
The metadata record for the recording (e.g., from the database).
|
|
163
|
+
preload : bool
|
|
164
|
+
Whether to pre-loaded the data before the first access.
|
|
165
|
+
cache_dir : str
|
|
166
|
+
Local path under which the data will be cached.
|
|
167
|
+
bids_dependencies : list
|
|
168
|
+
List of additional BIDS metadata files that should be downloaded and cached
|
|
169
|
+
alongside the main recording file.
|
|
170
|
+
verbose : str | int | None
|
|
171
|
+
Optionally the verbosity level for MNE logging (see MNE documentation for possible values).
|
|
131
172
|
|
|
132
173
|
See Also
|
|
133
174
|
--------
|
|
134
175
|
mne.io.Raw : Documentation of attributes and methods.
|
|
135
176
|
|
|
136
|
-
Notes
|
|
137
|
-
-----
|
|
138
|
-
.. versionadded:: 0.11.0
|
|
139
177
|
"""
|
|
140
178
|
|
|
141
|
-
AWS_BUCKET =
|
|
179
|
+
AWS_BUCKET = "s3://openneuro.org"
|
|
180
|
+
|
|
142
181
|
def __init__(
|
|
143
182
|
self,
|
|
144
|
-
input_fname,
|
|
145
|
-
metadata,
|
|
146
|
-
|
|
147
|
-
preload=False,
|
|
183
|
+
input_fname: str,
|
|
184
|
+
metadata: dict[str, Any],
|
|
185
|
+
preload: bool = False,
|
|
148
186
|
*,
|
|
149
|
-
cache_dir=
|
|
150
|
-
bids_dependencies:list = [],
|
|
151
|
-
|
|
152
|
-
montage_units="auto",
|
|
153
|
-
verbose=None,
|
|
187
|
+
cache_dir: str = "./.eegdash_cache",
|
|
188
|
+
bids_dependencies: list[str] = [],
|
|
189
|
+
verbose: Any = None,
|
|
154
190
|
):
|
|
155
|
-
|
|
156
|
-
Get to work with S3 endpoint first, no caching
|
|
157
|
-
'''
|
|
191
|
+
"""Get to work with S3 endpoint first, no caching"""
|
|
158
192
|
# Create a simple RawArray
|
|
159
|
-
sfreq = metadata[
|
|
160
|
-
n_times = metadata[
|
|
161
|
-
ch_names = metadata[
|
|
193
|
+
sfreq = metadata["sfreq"] # Sampling frequency
|
|
194
|
+
n_times = metadata["n_times"]
|
|
195
|
+
ch_names = metadata["ch_names"]
|
|
162
196
|
ch_types = []
|
|
163
|
-
for ch in metadata[
|
|
197
|
+
for ch in metadata["ch_types"]:
|
|
164
198
|
chtype = ch.lower()
|
|
165
|
-
if chtype ==
|
|
166
|
-
chtype =
|
|
199
|
+
if chtype == "heog" or chtype == "veog":
|
|
200
|
+
chtype = "eog"
|
|
167
201
|
ch_types.append(chtype)
|
|
168
202
|
info = mne.create_info(ch_names=ch_names, sfreq=sfreq, ch_types=ch_types)
|
|
169
203
|
self.s3file = self.get_s3path(input_fname)
|
|
@@ -178,7 +212,7 @@ class EEGDashBaseRaw(BaseRaw):
|
|
|
178
212
|
super().__init__(
|
|
179
213
|
info,
|
|
180
214
|
preload,
|
|
181
|
-
last_samps=[n_times-1],
|
|
215
|
+
last_samps=[n_times - 1],
|
|
182
216
|
orig_format="single",
|
|
183
217
|
verbose=verbose,
|
|
184
218
|
)
|
|
@@ -188,12 +222,16 @@ class EEGDashBaseRaw(BaseRaw):
|
|
|
188
222
|
|
|
189
223
|
def _download_s3(self):
|
|
190
224
|
self.filecache.parent.mkdir(parents=True, exist_ok=True)
|
|
191
|
-
filesystem = s3fs.S3FileSystem(
|
|
225
|
+
filesystem = s3fs.S3FileSystem(
|
|
226
|
+
anon=True, client_kwargs={"region_name": "us-east-2"}
|
|
227
|
+
)
|
|
192
228
|
filesystem.download(self.s3file, self.filecache)
|
|
193
229
|
self.filenames = [self.filecache]
|
|
194
230
|
|
|
195
231
|
def _download_dependencies(self):
|
|
196
|
-
filesystem = s3fs.S3FileSystem(
|
|
232
|
+
filesystem = s3fs.S3FileSystem(
|
|
233
|
+
anon=True, client_kwargs={"region_name": "us-east-2"}
|
|
234
|
+
)
|
|
197
235
|
for dep in self.bids_dependencies:
|
|
198
236
|
s3path = self.get_s3path(dep)
|
|
199
237
|
filepath = self.cache_dir / dep
|
|
@@ -204,34 +242,56 @@ class EEGDashBaseRaw(BaseRaw):
|
|
|
204
242
|
def _read_segment(
|
|
205
243
|
self, start=0, stop=None, sel=None, data_buffer=None, *, verbose=None
|
|
206
244
|
):
|
|
207
|
-
if not os.path.exists(self.filecache):
|
|
245
|
+
if not os.path.exists(self.filecache): # not preload
|
|
208
246
|
if self.bids_dependencies:
|
|
209
247
|
self._download_dependencies()
|
|
210
248
|
self._download_s3()
|
|
211
|
-
else:
|
|
249
|
+
else: # not preload and file is not cached
|
|
212
250
|
self.filenames = [self.filecache]
|
|
213
251
|
return super()._read_segment(start, stop, sel, data_buffer, verbose=verbose)
|
|
214
|
-
|
|
252
|
+
|
|
215
253
|
def _read_segment_file(self, data, idx, fi, start, stop, cals, mult):
|
|
216
254
|
"""Read a chunk of data from the file."""
|
|
217
255
|
_read_segments_file(self, data, idx, fi, start, stop, cals, mult, dtype="<f4")
|
|
218
256
|
|
|
219
257
|
|
|
220
|
-
class EEGBIDSDataset
|
|
221
|
-
|
|
258
|
+
class EEGBIDSDataset:
|
|
259
|
+
"""A one-stop shop interface to a local BIDS dataset containing EEG recordings.
|
|
260
|
+
|
|
261
|
+
This is mainly tailored to the needs of EEGDash application and is used to centralize
|
|
262
|
+
interactions with the BIDS dataset, such as parsing the metadata.
|
|
263
|
+
|
|
264
|
+
Parameters
|
|
265
|
+
----------
|
|
266
|
+
data_dir : str | Path
|
|
267
|
+
The path to the local BIDS dataset directory.
|
|
268
|
+
dataset : str
|
|
269
|
+
A name for the dataset.
|
|
270
|
+
|
|
271
|
+
"""
|
|
272
|
+
|
|
273
|
+
ALLOWED_FILE_FORMAT = ["eeglab", "brainvision", "biosemi", "european"]
|
|
222
274
|
RAW_EXTENSIONS = {
|
|
223
|
-
|
|
224
|
-
|
|
225
|
-
|
|
226
|
-
|
|
227
|
-
|
|
228
|
-
METADATA_FILE_EXTENSIONS = [
|
|
229
|
-
|
|
230
|
-
|
|
231
|
-
|
|
232
|
-
|
|
275
|
+
".set": [".set", ".fdt"], # eeglab
|
|
276
|
+
".edf": [".edf"], # european
|
|
277
|
+
".vhdr": [".eeg", ".vhdr", ".vmrk", ".dat", ".raw"], # brainvision
|
|
278
|
+
".bdf": [".bdf"], # biosemi
|
|
279
|
+
}
|
|
280
|
+
METADATA_FILE_EXTENSIONS = [
|
|
281
|
+
"eeg.json",
|
|
282
|
+
"channels.tsv",
|
|
283
|
+
"electrodes.tsv",
|
|
284
|
+
"events.tsv",
|
|
285
|
+
"events.json",
|
|
286
|
+
]
|
|
287
|
+
|
|
288
|
+
def __init__(
|
|
289
|
+
self,
|
|
290
|
+
data_dir=None, # location of bids dataset
|
|
291
|
+
dataset="", # dataset name
|
|
292
|
+
):
|
|
233
293
|
if data_dir is None or not os.path.exists(data_dir):
|
|
234
|
-
raise ValueError(
|
|
294
|
+
raise ValueError("data_dir must be specified and must exist")
|
|
235
295
|
self.bidsdir = Path(data_dir)
|
|
236
296
|
self.dataset = dataset
|
|
237
297
|
assert str(self.bidsdir).endswith(self.dataset)
|
|
@@ -239,73 +299,87 @@ class EEGBIDSDataset():
|
|
|
239
299
|
|
|
240
300
|
# get all recording files in the bids directory
|
|
241
301
|
self.files = self.get_recordings(self.layout)
|
|
242
|
-
assert len(self.files) > 0, ValueError(
|
|
243
|
-
|
|
244
|
-
|
|
245
|
-
|
|
246
|
-
|
|
247
|
-
|
|
248
|
-
|
|
249
|
-
|
|
250
|
-
|
|
251
|
-
|
|
252
|
-
|
|
253
|
-
def check_eeg_dataset(self):
|
|
254
|
-
return self.get_bids_file_attribute('modality', self.files[0]).lower() == 'eeg'
|
|
255
|
-
|
|
256
|
-
def get_recordings(self, layout:BIDSLayout):
|
|
302
|
+
assert len(self.files) > 0, ValueError(
|
|
303
|
+
"Unable to construct EEG dataset. No EEG recordings found."
|
|
304
|
+
)
|
|
305
|
+
assert self.check_eeg_dataset(), ValueError("Dataset is not an EEG dataset.")
|
|
306
|
+
|
|
307
|
+
def check_eeg_dataset(self) -> bool:
|
|
308
|
+
"""Check if the dataset is EEG."""
|
|
309
|
+
return self.get_bids_file_attribute("modality", self.files[0]).lower() == "eeg"
|
|
310
|
+
|
|
311
|
+
def get_recordings(self, layout: BIDSLayout) -> list[str]:
|
|
312
|
+
"""Get a list of all EEG recording files in the BIDS layout."""
|
|
257
313
|
files = []
|
|
258
314
|
for ext, exts in self.RAW_EXTENSIONS.items():
|
|
259
|
-
files = layout.get(extension=ext, return_type=
|
|
315
|
+
files = layout.get(extension=ext, return_type="filename")
|
|
260
316
|
if files:
|
|
261
|
-
break
|
|
317
|
+
break
|
|
262
318
|
return files
|
|
263
319
|
|
|
264
|
-
def get_relative_bidspath(self, filename):
|
|
265
|
-
|
|
320
|
+
def get_relative_bidspath(self, filename: str) -> str:
|
|
321
|
+
"""Make the given file path relative to the BIDS directory."""
|
|
322
|
+
bids_parent_dir = self.bidsdir.parent.absolute()
|
|
266
323
|
return str(Path(filename).relative_to(bids_parent_dir))
|
|
267
324
|
|
|
268
|
-
def get_property_from_filename(self, property, filename):
|
|
325
|
+
def get_property_from_filename(self, property: str, filename: str) -> str:
|
|
326
|
+
"""Parse a property out of a BIDS-compliant filename. Returns an empty string
|
|
327
|
+
if not found.
|
|
328
|
+
"""
|
|
269
329
|
import platform
|
|
330
|
+
|
|
270
331
|
if platform.system() == "Windows":
|
|
271
|
-
lookup = re.search(rf
|
|
332
|
+
lookup = re.search(rf"{property}-(.*?)[_\\]", filename)
|
|
272
333
|
else:
|
|
273
|
-
lookup = re.search(rf
|
|
274
|
-
return lookup.group(1) if lookup else
|
|
275
|
-
|
|
276
|
-
def merge_json_inheritance(self, json_files):
|
|
277
|
-
|
|
278
|
-
|
|
279
|
-
|
|
280
|
-
|
|
281
|
-
'''
|
|
334
|
+
lookup = re.search(rf"{property}-(.*?)[_\/]", filename)
|
|
335
|
+
return lookup.group(1) if lookup else ""
|
|
336
|
+
|
|
337
|
+
def merge_json_inheritance(self, json_files: list[str | Path]) -> dict:
|
|
338
|
+
"""Internal helper to merge list of json files found by get_bids_file_inheritance,
|
|
339
|
+
expecting the order (from left to right) is from lowest
|
|
340
|
+
level to highest level, and return a merged dictionary
|
|
341
|
+
"""
|
|
282
342
|
json_files.reverse()
|
|
283
343
|
json_dict = {}
|
|
284
344
|
for f in json_files:
|
|
285
|
-
json_dict.update(json.load(open(f)))
|
|
345
|
+
json_dict.update(json.load(open(f))) # FIXME: should close file
|
|
286
346
|
return json_dict
|
|
287
347
|
|
|
288
|
-
def get_bids_file_inheritance(
|
|
289
|
-
|
|
290
|
-
|
|
291
|
-
|
|
292
|
-
|
|
293
|
-
|
|
294
|
-
|
|
295
|
-
|
|
296
|
-
|
|
348
|
+
def get_bids_file_inheritance(
|
|
349
|
+
self, path: str | Path, basename: str, extension: str
|
|
350
|
+
) -> list[Path]:
|
|
351
|
+
"""Get all file paths that apply to the basename file in the specified directory
|
|
352
|
+
and that end with the specified suffix, recursively searching parent directories
|
|
353
|
+
(following the BIDS inheritance principle in the order of lowest level first).
|
|
354
|
+
|
|
355
|
+
Parameters
|
|
356
|
+
----------
|
|
357
|
+
path : str | Path
|
|
358
|
+
The directory path to search for files.
|
|
359
|
+
basename : str
|
|
360
|
+
BIDS file basename without _eeg.set extension for example
|
|
361
|
+
extension : str
|
|
362
|
+
Only consider files that end with the specified suffix; e.g. channels.tsv
|
|
363
|
+
|
|
364
|
+
Returns
|
|
365
|
+
-------
|
|
366
|
+
list[Path]
|
|
367
|
+
A list of file paths that match the given basename and extension.
|
|
368
|
+
|
|
369
|
+
"""
|
|
370
|
+
top_level_files = ["README", "dataset_description.json", "participants.tsv"]
|
|
297
371
|
bids_files = []
|
|
298
372
|
|
|
299
373
|
# check if path is str object
|
|
300
374
|
if isinstance(path, str):
|
|
301
375
|
path = Path(path)
|
|
302
376
|
if not path.exists:
|
|
303
|
-
raise ValueError(
|
|
377
|
+
raise ValueError("path {path} does not exist")
|
|
304
378
|
|
|
305
379
|
# check if file is in current path
|
|
306
380
|
for file in os.listdir(path):
|
|
307
381
|
# target_file = path / f"{cur_file_basename}_{extension}"
|
|
308
|
-
if os.path.isfile(path/file):
|
|
382
|
+
if os.path.isfile(path / file):
|
|
309
383
|
# check if file has extension extension
|
|
310
384
|
# check if file basename has extension
|
|
311
385
|
if file.endswith(extension):
|
|
@@ -317,38 +391,54 @@ class EEGBIDSDataset():
|
|
|
317
391
|
return bids_files
|
|
318
392
|
else:
|
|
319
393
|
# call get_bids_file_inheritance recursively with parent directory
|
|
320
|
-
bids_files.extend(
|
|
394
|
+
bids_files.extend(
|
|
395
|
+
self.get_bids_file_inheritance(path.parent, basename, extension)
|
|
396
|
+
)
|
|
321
397
|
return bids_files
|
|
322
398
|
|
|
323
|
-
def get_bids_metadata_files(
|
|
324
|
-
|
|
325
|
-
|
|
326
|
-
|
|
327
|
-
|
|
328
|
-
|
|
329
|
-
|
|
330
|
-
|
|
331
|
-
|
|
332
|
-
|
|
333
|
-
|
|
399
|
+
def get_bids_metadata_files(
|
|
400
|
+
self, filepath: str | Path, metadata_file_extension: list[str]
|
|
401
|
+
) -> list[Path]:
|
|
402
|
+
"""Retrieve all metadata file paths that apply to a given data file path and that
|
|
403
|
+
end with a specific suffix (following the BIDS inheritance principle).
|
|
404
|
+
|
|
405
|
+
Parameters
|
|
406
|
+
----------
|
|
407
|
+
filepath: str | Path
|
|
408
|
+
The filepath to get the associated metadata files for.
|
|
409
|
+
metadata_file_extension : str
|
|
410
|
+
Consider only metadata files that end with the specified suffix,
|
|
411
|
+
e.g., channels.tsv or eeg.json
|
|
412
|
+
|
|
413
|
+
Returns
|
|
414
|
+
-------
|
|
415
|
+
list[Path]:
|
|
416
|
+
A list of filepaths for all matching metadata files
|
|
417
|
+
|
|
334
418
|
"""
|
|
335
419
|
if isinstance(filepath, str):
|
|
336
420
|
filepath = Path(filepath)
|
|
337
421
|
if not filepath.exists:
|
|
338
|
-
raise ValueError(
|
|
422
|
+
raise ValueError("filepath {filepath} does not exist")
|
|
339
423
|
path, filename = os.path.split(filepath)
|
|
340
|
-
basename = filename[:filename.rfind(
|
|
424
|
+
basename = filename[: filename.rfind("_")]
|
|
341
425
|
# metadata files
|
|
342
|
-
meta_files = self.get_bids_file_inheritance(
|
|
426
|
+
meta_files = self.get_bids_file_inheritance(
|
|
427
|
+
path, basename, metadata_file_extension
|
|
428
|
+
)
|
|
343
429
|
return meta_files
|
|
344
|
-
|
|
345
|
-
def scan_directory(self, directory, extension):
|
|
430
|
+
|
|
431
|
+
def scan_directory(self, directory: str, extension: str) -> list[Path]:
|
|
432
|
+
"""Return a list of file paths that end with the given extension in the specified
|
|
433
|
+
directory. Ignores certain special directories like .git, .datalad, derivatives,
|
|
434
|
+
and code.
|
|
435
|
+
"""
|
|
346
436
|
result_files = []
|
|
347
|
-
directory_to_ignore = [
|
|
437
|
+
directory_to_ignore = [".git", ".datalad", "derivatives", "code"]
|
|
348
438
|
with os.scandir(directory) as entries:
|
|
349
439
|
for entry in entries:
|
|
350
440
|
if entry.is_file() and entry.name.endswith(extension):
|
|
351
|
-
print(
|
|
441
|
+
print("Adding ", entry.path)
|
|
352
442
|
result_files.append(entry.path)
|
|
353
443
|
elif entry.is_dir():
|
|
354
444
|
# check that entry path doesn't contain any name in ignore list
|
|
@@ -356,18 +446,41 @@ class EEGBIDSDataset():
|
|
|
356
446
|
result_files.append(entry.path) # Add directory to scan later
|
|
357
447
|
return result_files
|
|
358
448
|
|
|
359
|
-
def get_files_with_extension_parallel(
|
|
449
|
+
def get_files_with_extension_parallel(
|
|
450
|
+
self, directory: str, extension: str = ".set", max_workers: int = -1
|
|
451
|
+
) -> list[Path]:
|
|
452
|
+
"""Efficiently scan a directory and its subdirectories for files that end with
|
|
453
|
+
the given extension.
|
|
454
|
+
|
|
455
|
+
Parameters
|
|
456
|
+
----------
|
|
457
|
+
directory : str
|
|
458
|
+
The root directory to scan for files.
|
|
459
|
+
extension : str
|
|
460
|
+
Only consider files that end with this suffix, e.g. '.set'.
|
|
461
|
+
max_workers : int
|
|
462
|
+
Optionally specify the maximum number of worker threads to use for parallel scanning.
|
|
463
|
+
Defaults to all available CPU cores if set to -1.
|
|
464
|
+
|
|
465
|
+
Returns
|
|
466
|
+
-------
|
|
467
|
+
list[Path]:
|
|
468
|
+
A list of filepaths for all matching metadata files
|
|
469
|
+
|
|
470
|
+
"""
|
|
360
471
|
result_files = []
|
|
361
472
|
dirs_to_scan = [directory]
|
|
362
473
|
|
|
363
474
|
# Use joblib.Parallel and delayed to parallelize directory scanning
|
|
364
475
|
while dirs_to_scan:
|
|
365
|
-
|
|
476
|
+
logger.info(
|
|
477
|
+
f"Directories to scan: {len(dirs_to_scan)}, files: {dirs_to_scan}"
|
|
478
|
+
)
|
|
366
479
|
# Run the scan_directory function in parallel across directories
|
|
367
480
|
results = Parallel(n_jobs=max_workers, prefer="threads", verbose=1)(
|
|
368
481
|
delayed(self.scan_directory)(d, extension) for d in dirs_to_scan
|
|
369
482
|
)
|
|
370
|
-
|
|
483
|
+
|
|
371
484
|
# Reset the directories to scan and process the results
|
|
372
485
|
dirs_to_scan = []
|
|
373
486
|
for res in results:
|
|
@@ -376,14 +489,20 @@ class EEGBIDSDataset():
|
|
|
376
489
|
dirs_to_scan.append(path) # Queue up subdirectories to scan
|
|
377
490
|
else:
|
|
378
491
|
result_files.append(path) # Add files to the final result
|
|
379
|
-
|
|
492
|
+
logger.info(f"Found {len(result_files)} files.")
|
|
380
493
|
|
|
381
494
|
return result_files
|
|
382
495
|
|
|
383
|
-
def load_and_preprocess_raw(
|
|
384
|
-
|
|
385
|
-
|
|
386
|
-
|
|
496
|
+
def load_and_preprocess_raw(
|
|
497
|
+
self, raw_file: str, preprocess: bool = False
|
|
498
|
+
) -> np.ndarray:
|
|
499
|
+
"""Utility function to load a raw data file with MNE and apply some simple
|
|
500
|
+
(hardcoded) preprocessing and return as a numpy array. Not meant for purposes
|
|
501
|
+
other than testing or debugging.
|
|
502
|
+
"""
|
|
503
|
+
logger.info(f"Loading raw data from {raw_file}")
|
|
504
|
+
EEG = mne.io.read_raw_eeglab(raw_file, preload=True, verbose="error")
|
|
505
|
+
|
|
387
506
|
if preprocess:
|
|
388
507
|
# highpass filter
|
|
389
508
|
EEG = EEG.filter(l_freq=0.25, h_freq=25, verbose=False)
|
|
@@ -391,33 +510,35 @@ class EEGBIDSDataset():
|
|
|
391
510
|
EEG = EEG.notch_filter(freqs=(60), verbose=False)
|
|
392
511
|
# bring to common sampling rate
|
|
393
512
|
sfreq = 128
|
|
394
|
-
if EEG.info[
|
|
513
|
+
if EEG.info["sfreq"] != sfreq:
|
|
395
514
|
EEG = EEG.resample(sfreq)
|
|
396
|
-
# # normalize data to zero mean and unit variance
|
|
397
|
-
# scalar = preprocessing.StandardScaler()
|
|
398
|
-
# mat_data = scalar.fit_transform(mat_data.T).T # scalar normalize for each feature and expects shape data x features
|
|
399
515
|
|
|
400
516
|
mat_data = EEG.get_data()
|
|
401
517
|
|
|
402
518
|
if len(mat_data.shape) > 2:
|
|
403
|
-
raise ValueError(
|
|
519
|
+
raise ValueError("Expect raw data to be CxT dimension")
|
|
404
520
|
return mat_data
|
|
405
|
-
|
|
406
|
-
def get_files(self):
|
|
521
|
+
|
|
522
|
+
def get_files(self) -> list[Path]:
|
|
523
|
+
"""Get all EEG recording file paths (with valid extensions) in the BIDS folder."""
|
|
407
524
|
return self.files
|
|
408
|
-
|
|
409
|
-
def resolve_bids_json(self, json_files: list):
|
|
410
|
-
"""
|
|
411
|
-
Resolve the BIDS JSON files and return a dictionary of the resolved values.
|
|
412
|
-
Args:
|
|
413
|
-
json_files (list): A list of JSON files to resolve in order of leaf level first
|
|
414
525
|
|
|
415
|
-
|
|
526
|
+
def resolve_bids_json(self, json_files: list[str]) -> dict:
|
|
527
|
+
"""Resolve the BIDS JSON files and return a dictionary of the resolved values.
|
|
528
|
+
|
|
529
|
+
Parameters
|
|
530
|
+
----------
|
|
531
|
+
json_files : list
|
|
532
|
+
A list of JSON file paths to resolve in order of leaf level first.
|
|
533
|
+
|
|
534
|
+
Returns
|
|
535
|
+
-------
|
|
416
536
|
dict: A dictionary of the resolved values.
|
|
537
|
+
|
|
417
538
|
"""
|
|
418
539
|
if len(json_files) == 0:
|
|
419
|
-
raise ValueError(
|
|
420
|
-
json_files.reverse()
|
|
540
|
+
raise ValueError("No JSON files provided")
|
|
541
|
+
json_files.reverse() # TODO undeterministic
|
|
421
542
|
|
|
422
543
|
json_dict = {}
|
|
423
544
|
for json_file in json_files:
|
|
@@ -425,56 +546,78 @@ class EEGBIDSDataset():
|
|
|
425
546
|
json_dict.update(json.load(f))
|
|
426
547
|
return json_dict
|
|
427
548
|
|
|
428
|
-
def get_bids_file_attribute(self, attribute, data_filepath):
|
|
549
|
+
def get_bids_file_attribute(self, attribute: str, data_filepath: str) -> Any:
|
|
550
|
+
"""Retrieve a specific attribute from the BIDS file metadata applicable
|
|
551
|
+
to the provided recording file path.
|
|
552
|
+
"""
|
|
429
553
|
entities = self.layout.parse_file_entities(data_filepath)
|
|
430
554
|
bidsfile = self.layout.get(**entities)[0]
|
|
431
|
-
attributes = bidsfile.get_entities(metadata=
|
|
555
|
+
attributes = bidsfile.get_entities(metadata="all")
|
|
432
556
|
attribute_mapping = {
|
|
433
|
-
|
|
434
|
-
|
|
435
|
-
|
|
436
|
-
|
|
437
|
-
|
|
438
|
-
|
|
439
|
-
|
|
440
|
-
|
|
557
|
+
"sfreq": "SamplingFrequency",
|
|
558
|
+
"modality": "datatype",
|
|
559
|
+
"task": "task",
|
|
560
|
+
"session": "session",
|
|
561
|
+
"run": "run",
|
|
562
|
+
"subject": "subject",
|
|
563
|
+
"ntimes": "RecordingDuration",
|
|
564
|
+
"nchans": "EEGChannelCount",
|
|
441
565
|
}
|
|
442
566
|
attribute_value = attributes.get(attribute_mapping.get(attribute), None)
|
|
443
567
|
return attribute_value
|
|
444
568
|
|
|
445
|
-
def channel_labels(self, data_filepath):
|
|
446
|
-
|
|
447
|
-
|
|
448
|
-
|
|
449
|
-
|
|
450
|
-
channels_tsv
|
|
451
|
-
|
|
452
|
-
|
|
453
|
-
|
|
454
|
-
|
|
569
|
+
def channel_labels(self, data_filepath: str) -> list[str]:
|
|
570
|
+
"""Get a list of channel labels for the given data file path."""
|
|
571
|
+
channels_tsv = pd.read_csv(
|
|
572
|
+
self.get_bids_metadata_files(data_filepath, "channels.tsv")[0], sep="\t"
|
|
573
|
+
)
|
|
574
|
+
return channels_tsv["name"].tolist()
|
|
575
|
+
|
|
576
|
+
def channel_types(self, data_filepath: str) -> list[str]:
|
|
577
|
+
"""Get a list of channel types for the given data file path."""
|
|
578
|
+
channels_tsv = pd.read_csv(
|
|
579
|
+
self.get_bids_metadata_files(data_filepath, "channels.tsv")[0], sep="\t"
|
|
580
|
+
)
|
|
581
|
+
return channels_tsv["type"].tolist()
|
|
582
|
+
|
|
583
|
+
def num_times(self, data_filepath: str) -> int:
|
|
584
|
+
"""Get the approximate number of time points in the EEG recording based on the BIDS metadata."""
|
|
585
|
+
eeg_jsons = self.get_bids_metadata_files(data_filepath, "eeg.json")
|
|
455
586
|
eeg_json_dict = self.merge_json_inheritance(eeg_jsons)
|
|
456
|
-
return int(
|
|
457
|
-
|
|
458
|
-
|
|
459
|
-
|
|
460
|
-
|
|
587
|
+
return int(
|
|
588
|
+
eeg_json_dict["SamplingFrequency"] * eeg_json_dict["RecordingDuration"]
|
|
589
|
+
)
|
|
590
|
+
|
|
591
|
+
def subject_participant_tsv(self, data_filepath: str) -> dict[str, Any]:
|
|
592
|
+
"""Get BIDS participants.tsv record for the subject to which the given file
|
|
593
|
+
path corresponds, as a dictionary.
|
|
594
|
+
"""
|
|
595
|
+
participants_tsv = pd.read_csv(
|
|
596
|
+
self.get_bids_metadata_files(data_filepath, "participants.tsv")[0], sep="\t"
|
|
597
|
+
)
|
|
461
598
|
# if participants_tsv is not empty
|
|
462
599
|
if participants_tsv.empty:
|
|
463
600
|
return {}
|
|
464
601
|
# set 'participant_id' as index
|
|
465
|
-
participants_tsv.set_index(
|
|
602
|
+
participants_tsv.set_index("participant_id", inplace=True)
|
|
466
603
|
subject = f"sub-{self.get_bids_file_attribute('subject', data_filepath)}"
|
|
467
604
|
return participants_tsv.loc[subject].to_dict()
|
|
468
|
-
|
|
469
|
-
def eeg_json(self, data_filepath):
|
|
470
|
-
|
|
605
|
+
|
|
606
|
+
def eeg_json(self, data_filepath: str) -> dict[str, Any]:
|
|
607
|
+
"""Get BIDS eeg.json metadata for the given data file path."""
|
|
608
|
+
eeg_jsons = self.get_bids_metadata_files(data_filepath, "eeg.json")
|
|
471
609
|
eeg_json_dict = self.merge_json_inheritance(eeg_jsons)
|
|
472
610
|
return eeg_json_dict
|
|
473
|
-
|
|
474
|
-
def channel_tsv(self, data_filepath):
|
|
475
|
-
|
|
611
|
+
|
|
612
|
+
def channel_tsv(self, data_filepath: str) -> dict[str, Any]:
|
|
613
|
+
"""Get BIDS channels.tsv metadata for the given data file path, as a dictionary
|
|
614
|
+
of lists and/or single values.
|
|
615
|
+
"""
|
|
616
|
+
channels_tsv = pd.read_csv(
|
|
617
|
+
self.get_bids_metadata_files(data_filepath, "channels.tsv")[0], sep="\t"
|
|
618
|
+
)
|
|
476
619
|
channel_tsv = channels_tsv.to_dict()
|
|
477
620
|
# 'name' and 'type' now have a dictionary of index-value. Convert them to list
|
|
478
|
-
for list_field in [
|
|
621
|
+
for list_field in ["name", "type", "units"]:
|
|
479
622
|
channel_tsv[list_field] = list(channel_tsv[list_field].values())
|
|
480
|
-
return channel_tsv
|
|
623
|
+
return channel_tsv
|