eegdash 0.1.0__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of eegdash might be problematic.
- eegdash/__init__.py +7 -3
- eegdash/api.py +690 -0
- eegdash/data_config.py +7 -1
- eegdash/data_utils.py +215 -118
- eegdash/dataset.py +60 -0
- eegdash/features/__init__.py +37 -9
- eegdash/features/datasets.py +57 -21
- eegdash/features/decorators.py +10 -2
- eegdash/features/extractors.py +20 -21
- eegdash/features/feature_bank/complexity.py +4 -0
- eegdash/features/feature_bank/csp.py +2 -2
- eegdash/features/feature_bank/dimensionality.py +7 -3
- eegdash/features/feature_bank/signal.py +29 -3
- eegdash/features/inspect.py +48 -0
- eegdash/features/serialization.py +2 -3
- eegdash/features/utils.py +1 -1
- eegdash/preprocessing.py +65 -0
- eegdash/utils.py +11 -0
- {eegdash-0.1.0.dist-info → eegdash-0.2.0.dist-info}/METADATA +49 -6
- eegdash-0.2.0.dist-info/RECORD +27 -0
- {eegdash-0.1.0.dist-info → eegdash-0.2.0.dist-info}/WHEEL +1 -1
- {eegdash-0.1.0.dist-info → eegdash-0.2.0.dist-info}/licenses/LICENSE +1 -0
- eegdash/main.py +0 -416
- eegdash-0.1.0.dist-info/RECORD +0 -23
- {eegdash-0.1.0.dist-info → eegdash-0.2.0.dist-info}/top_level.txt +0 -0
eegdash/__init__.py
CHANGED

```diff
@@ -1,4 +1,8 @@
-from .
+from .api import EEGDash, EEGDashDataset
+from .dataset import EEGChallengeDataset
+from .utils import __init__mongo_client
 
-
-
+__init__mongo_client()
+
+__all__ = ["EEGDash", "EEGDashDataset", "EEGChallengeDataset"]
+__version__ = "0.2.0"
```
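The rewritten `__init__` re-exports the public classes at the package top level and runs the MongoDB client setup on import. A minimal sketch of the resulting import surface (the query values are illustrative, and it is an assumption that `__init__mongo_client` only configures connection defaults):

```python
# Top-level names re-exported by eegdash/__init__.py in 0.2.0.
from eegdash import EEGChallengeDataset, EEGDash, EEGDashDataset

# __init__mongo_client() has already run at import time (connection setup).
client = EEGDash()
records = client.find({"dataset": "ds002718", "subject": "012"})
```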
eegdash/api.py
ADDED

@@ -0,0 +1,690 @@

```python
import logging
import os
import tempfile
from pathlib import Path
from typing import Any, Mapping

import mne
import numpy as np
import xarray as xr
from dotenv import load_dotenv
from joblib import Parallel, delayed
from pymongo import InsertOne, MongoClient, UpdateOne
from s3fs import S3FileSystem

from braindecode.datasets import BaseConcatDataset

from .data_config import config as data_config
from .data_utils import EEGBIDSDataset, EEGDashBaseDataset

logger = logging.getLogger("eegdash")


class EEGDash:
    """A high-level interface to the EEGDash database.

    This class is primarily used to interact with the metadata records stored in the
    EEGDash database (or a private instance of it), allowing users to find, add, and
    update EEG data records.

    While this class provides basic support for loading EEG data, please see
    the EEGDashDataset class for a more complete way to retrieve and work with full
    datasets.

    """

    AWS_BUCKET = "s3://openneuro.org"

    def __init__(self, *, is_public: bool = True, is_staging: bool = False) -> None:
        """Create a new instance of the EEGDash database client.

        Parameters
        ----------
        is_public : bool
            Whether to connect to the public MongoDB database; if False, connect to a
            private database instance as per the DB_CONNECTION_STRING environment
            variable (or .env file entry).
        is_staging : bool
            If True, use the staging MongoDB database ("eegdashstaging"); otherwise
            use the production database ("eegdash").

        Example
        -------
        >>> eegdash = EEGDash()

        """
        self.config = data_config
        self.is_public = is_public

        if self.is_public:
            DB_CONNECTION_STRING = mne.utils.get_config("EEGDASH_DB_URI")
        else:
            load_dotenv()
            DB_CONNECTION_STRING = os.getenv("DB_CONNECTION_STRING")

        self.__client = MongoClient(DB_CONNECTION_STRING)
        self.__db = (
            self.__client["eegdash"]
            if not is_staging
            else self.__client["eegdashstaging"]
        )
        self.__collection = self.__db["records"]

        self.filesystem = S3FileSystem(
            anon=True, client_kwargs={"region_name": "us-east-2"}
        )

    # MongoDB Operations
    # These methods provide a high-level interface to interact with the MongoDB
    # collection, allowing users to find, add, and update EEG data records.
    # - find: Find records matching a query.
    # - exist: Check whether any record matches a query.
    # - add_request: Build an insert request for bulk writes.
    # - add: Insert a single record.
    # - update_request: Build an update request for bulk writes.
    # - remove_field: Remove a field from a single record.
    # - remove_field_from_db: Remove a field from all records.
    # - close: Close the MongoDB connection.
    # - __del__: Destructor to close the MongoDB connection.

    def find(self, query: dict[str, Any], *args, **kwargs) -> list[Mapping[str, Any]]:
        """Find records in the MongoDB collection that satisfy the given query.

        Parameters
        ----------
        query : dict
            A dictionary that specifies the query to be executed; this is a reference
            document that is used to match records in the MongoDB collection.
        args :
            Additional positional arguments for the MongoDB find() method; see
            https://pymongo.readthedocs.io/en/stable/api/pymongo/collection.html#pymongo.collection.Collection.find
        kwargs :
            Additional keyword arguments for the MongoDB find() method.

        Returns
        -------
        list
            A list of DB records (string-keyed dictionaries) that match the query.

        Example
        -------
        >>> eegdash = EEGDash()
        >>> eegdash.find({"dataset": "ds002718", "subject": "012"})

        """
        results = self.__collection.find(query, *args, **kwargs)

        return [result for result in results]

    def exist(self, query: dict[str, Any]) -> bool:
        """Check if the given query matches any records in the MongoDB collection.

        Note that currently only a limited set of query fields is allowed here.

        Parameters
        ----------
        query : dict
            A dictionary that specifies the query to be executed; this is a reference
            document that is used to match records in the MongoDB collection.

        Returns
        -------
        bool
            True if at least one record matches the query, False otherwise.

        """
        accepted_query_fields = ["data_name", "dataset"]
        assert all(field in accepted_query_fields for field in query.keys())
        sessions = self.find(query)
        return len(sessions) > 0

    def _validate_input(self, record: dict[str, Any]) -> dict[str, Any]:
        """Internal method to validate the input record against the expected schema.

        Parameters
        ----------
        record : dict
            A dictionary representing the EEG data record to be validated.

        Returns
        -------
        dict
            The record itself on success; raises a ValueError if the record is
            invalid.

        """
        input_types = {
            "data_name": str,
            "dataset": str,
            "bidspath": str,
            "subject": str,
            "task": str,
            "session": str,
            "run": str,
            "sampling_frequency": float,
            "modality": str,
            "nchans": int,
            "ntimes": int,
            "channel_types": list,
            "channel_names": list,
        }
        if "data_name" not in record:
            raise ValueError("Missing key: data_name")
        # check that every key is known and that its value has the correct type
        for key, value in record.items():
            if key not in input_types:
                raise ValueError(f"Invalid input: {key}")
            if not isinstance(value, input_types[key]):
                raise ValueError(f"Invalid input: {key}")

        return record

    def load_eeg_data_from_s3(self, s3path: str) -> xr.DataArray:
        """Load an EEGLAB .set file from an AWS S3 URI and return it as an xarray DataArray.

        Parameters
        ----------
        s3path : str
            An S3 URI (should start with "s3://") for the file in question.

        Returns
        -------
        xr.DataArray
            A DataArray containing the EEG data, with dimensions "channel" and "time".

        Example
        -------
        >>> eegdash = EEGDash()
        >>> mypath = "s3://openneuro.org/path/to/your/eeg_data.set"
        >>> mydata = eegdash.load_eeg_data_from_s3(mypath)

        """
        with tempfile.NamedTemporaryFile(delete=False, suffix=".set") as tmp:
            with self.filesystem.open(s3path) as s3_file:
                tmp.write(s3_file.read())
            tmp_path = tmp.name
        eeg_data = self.load_eeg_data_from_bids_file(tmp_path)
        os.unlink(tmp_path)
        return eeg_data

    def load_eeg_data_from_bids_file(self, bids_file: str) -> xr.DataArray:
        """Load EEG data from a local file and return it as an xarray DataArray.

        Parameters
        ----------
        bids_file : str
            Path to the file on the local filesystem.

        Notes
        -----
        Currently, only non-epoched .set files are supported.

        """
        raw_object = mne.io.read_raw(bids_file)
        eeg_data = raw_object.get_data()

        fs = raw_object.info["sfreq"]
        max_time = eeg_data.shape[1] / fs
        time_steps = np.linspace(0, max_time, eeg_data.shape[1]).squeeze()  # in seconds

        channel_names = raw_object.ch_names

        eeg_xarray = xr.DataArray(
            data=eeg_data,
            dims=["channel", "time"],
            coords={"time": time_steps, "channel": channel_names},
        )
        return eeg_xarray
```
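Because `load_eeg_data_from_bids_file` attaches named "channel" and "time" coordinates, callers can slice the returned array by label instead of integer index. A minimal, self-contained sketch, with synthetic data standing in for a loaded recording:

```python
import numpy as np
import xarray as xr

# Synthetic stand-in for the DataArray returned above:
# dims ("channel", "time"), time in seconds, channels labeled by name.
data = xr.DataArray(
    np.random.randn(2, 500),
    dims=["channel", "time"],
    coords={"channel": ["Cz", "Pz"], "time": np.linspace(0, 2.0, 500)},
)

# Select one channel by name and the first second by time coordinate.
cz_first_second = data.sel(channel="Cz").sel(time=slice(0.0, 1.0))
print(cz_first_second.shape)
```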
```python
    def get_raw_extensions(
        self, bids_file: str, bids_dataset: EEGBIDSDataset
    ) -> list[str]:
        """Helper to find paths to additional "sidecar" files that may be associated
        with a given main data file in a BIDS dataset; paths are returned relative to
        the parent dataset path.

        For example, if the input file is a .set file, this will return the relative
        path to a corresponding .fdt file (if any).
        """
        bids_file = Path(bids_file)
        extensions = {
            ".set": [".set", ".fdt"],  # eeglab
            ".edf": [".edf"],  # european data format
            ".vhdr": [".eeg", ".vhdr", ".vmrk", ".dat", ".raw"],  # brainvision
            ".bdf": [".bdf"],  # biosemi
        }
        return [
            str(bids_dataset.get_relative_bidspath(bids_file.with_suffix(suffix)))
            for suffix in extensions[bids_file.suffix]
            if bids_file.with_suffix(suffix).exists()
        ]

    def load_eeg_attrs_from_bids_file(
        self, bids_dataset: EEGBIDSDataset, bids_file: str
    ) -> dict[str, Any]:
        """Build the metadata record for a given BIDS file (single recording) in a BIDS dataset.

        Attributes are at least the ones defined in the data_config attributes (set to
        None if missing), but are typically a superset, and include, among others, the
        paths to relevant metadata files needed to load and interpret the file in
        question.

        Parameters
        ----------
        bids_dataset : EEGBIDSDataset
            The BIDS dataset object containing the file.
        bids_file : str
            The path to the BIDS file within the dataset.

        Returns
        -------
        dict
            A dictionary representing the metadata record for the given file. This is
            the same format as the records stored in the database.

        """
        if bids_file not in bids_dataset.files:
            raise ValueError(f"{bids_file} not in {bids_dataset.dataset}")

        # Initialize attrs with None values for all expected fields
        attrs = {field: None for field in self.config["attributes"].keys()}

        file = Path(bids_file).name
        dsnumber = bids_dataset.dataset
        # Extract the OpenNeuro path by finding the first occurrence of the dataset
        # name in the filename and removing everything before it.
        openneuro_path = dsnumber + bids_file.split(dsnumber)[1]

        # Update with actual values where available
        try:
            participants_tsv = bids_dataset.subject_participant_tsv(bids_file)
        except Exception as e:
            logger.error("Error getting participants_tsv: %s", str(e))
            participants_tsv = None

        try:
            eeg_json = bids_dataset.eeg_json(bids_file)
        except Exception as e:
            logger.error("Error getting eeg_json: %s", str(e))
            eeg_json = None

        bids_dependencies_files = self.config["bids_dependencies_files"]
        bidsdependencies = []
        for extension in bids_dependencies_files:
            try:
                dep_path = bids_dataset.get_bids_metadata_files(bids_file, extension)
                dep_path = [
                    str(bids_dataset.get_relative_bidspath(dep)) for dep in dep_path
                ]
                bidsdependencies.extend(dep_path)
            except Exception:
                pass

        bidsdependencies.extend(self.get_raw_extensions(bids_file, bids_dataset))

        # Define field extraction functions with error handling
        field_extractors = {
            "data_name": lambda: f"{bids_dataset.dataset}_{file}",
            "dataset": lambda: bids_dataset.dataset,
            "bidspath": lambda: openneuro_path,
            "subject": lambda: bids_dataset.get_bids_file_attribute(
                "subject", bids_file
            ),
            "task": lambda: bids_dataset.get_bids_file_attribute("task", bids_file),
            "session": lambda: bids_dataset.get_bids_file_attribute(
                "session", bids_file
            ),
            "run": lambda: bids_dataset.get_bids_file_attribute("run", bids_file),
            "modality": lambda: bids_dataset.get_bids_file_attribute(
                "modality", bids_file
            ),
            "sampling_frequency": lambda: bids_dataset.get_bids_file_attribute(
                "sfreq", bids_file
            ),
            "nchans": lambda: bids_dataset.get_bids_file_attribute("nchans", bids_file),
            "ntimes": lambda: bids_dataset.get_bids_file_attribute("ntimes", bids_file),
            "participant_tsv": lambda: participants_tsv,
            "eeg_json": lambda: eeg_json,
            "bidsdependencies": lambda: bidsdependencies,
        }

        # Dynamically populate attrs with error handling
        for field, extractor in field_extractors.items():
            try:
                attrs[field] = extractor()
            except Exception as e:
                logger.error("Error extracting %s: %s", field, str(e))
                attrs[field] = None

        return attrs

    def add_bids_dataset(
        self, dataset: str, data_dir: str, overwrite: bool = True
    ) -> None:
        """Traverse the BIDS dataset at data_dir and add its records to the MongoDB
        database, under the given dataset name.

        Parameters
        ----------
        dataset : str
            The name of the dataset to be added (e.g., "ds002718").
        data_dir : str
            The path to the BIDS dataset directory.
        overwrite : bool
            Whether to overwrite/update existing records in the database.

        """
        if self.is_public:
            raise ValueError("This operation is not allowed for public users")

        if not overwrite and self.exist({"dataset": dataset}):
            logger.info("Dataset %s already exists in the database", dataset)
            return
        try:
            bids_dataset = EEGBIDSDataset(
                data_dir=data_dir,
                dataset=dataset,
            )
        except Exception as e:
            logger.error("Error creating bids dataset %s: %s", dataset, str(e))
            raise e
        requests = []
        for bids_file in bids_dataset.get_files():
            try:
                data_id = f"{dataset}_{Path(bids_file).name}"

                if self.exist({"data_name": data_id}):
                    if overwrite:
                        eeg_attrs = self.load_eeg_attrs_from_bids_file(
                            bids_dataset, bids_file
                        )
                        requests.append(self.update_request(eeg_attrs))
                else:
                    eeg_attrs = self.load_eeg_attrs_from_bids_file(
                        bids_dataset, bids_file
                    )
                    requests.append(self.add_request(eeg_attrs))
            except Exception as e:
                logger.error("Error adding record %s", bids_file)
                logger.error(str(e))

        logger.info("Number of requests: %s", len(requests))

        if requests:
            result = self.__collection.bulk_write(requests, ordered=False)
            logger.info("Inserted: %s", result.inserted_count)
            logger.info("Modified: %s", result.modified_count)
            logger.info("Deleted: %s", result.deleted_count)
            logger.info("Upserted: %s", result.upserted_count)
            logger.info("Errors: %s", result.bulk_api_result.get("writeErrors", []))
```
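`add_bids_dataset` collects InsertOne/UpdateOne requests and submits them in a single unordered `bulk_write`, so one failing record does not abort the rest. A standalone sketch of that pymongo pattern, with a placeholder connection string and made-up records:

```python
from pymongo import InsertOne, MongoClient, UpdateOne

client = MongoClient("mongodb://localhost:27017")  # placeholder URI
collection = client["eegdash"]["records"]

requests = [
    InsertOne({"data_name": "ds002718_sub-012_task-rest_eeg.set"}),
    UpdateOne(
        {"data_name": "ds002718_sub-013_task-rest_eeg.set"},
        {"$set": {"sampling_frequency": 250.0}},
    ),
]
# ordered=False lets the remaining operations proceed even if one fails.
result = collection.bulk_write(requests, ordered=False)
print(result.inserted_count, result.modified_count)
```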
```python
    def get(self, query: dict[str, Any]) -> list[xr.DataArray]:
        """Retrieve a list of EEG data arrays that match the given query. See also
        the `find()` method for details on the query format.

        Parameters
        ----------
        query : dict
            A dictionary that specifies the query to be executed; this is a reference
            document that is used to match records in the MongoDB collection.

        Returns
        -------
        list
            A list of xarray DataArray objects containing the EEG data for each
            matching record.

        Notes
        -----
        Retrieval is done in parallel, and the downloaded data are not cached locally.

        """
        sessions = self.find(query)
        results = []
        if sessions:
            logger.info("Found %s records", len(sessions))
            results = Parallel(
                n_jobs=-1 if len(sessions) > 1 else 1, prefer="threads", verbose=1
            )(
                delayed(self.load_eeg_data_from_s3)(self.get_s3path(session))
                for session in sessions
            )
        return results

    def add_request(self, record: dict):
        """Internal helper method to create a MongoDB insertion request for a record."""
        return InsertOne(record)

    def add(self, record: dict):
        """Add a single record to the MongoDB collection."""
        try:
            self.__collection.insert_one(record)
        except ValueError as e:
            logger.error("Validation error for record: %s", record["data_name"])
            logger.error(e)
        except Exception:
            logger.error("Error adding record: %s", record["data_name"])

    def update_request(self, record: dict):
        """Internal helper method to create a MongoDB update request for a record."""
        return UpdateOne({"data_name": record["data_name"]}, {"$set": record})

    def update(self, record: dict):
        """Update a single record in the MongoDB collection."""
        try:
            self.__collection.update_one(
                {"data_name": record["data_name"]}, {"$set": record}
            )
        except Exception:  # silent failure
            logger.error("Error updating record: %s", record["data_name"])

    def remove_field(self, record, field):
        """Remove a specific field from a record in the MongoDB collection."""
        self.__collection.update_one(
            {"data_name": record["data_name"]}, {"$unset": {field: 1}}
        )

    def remove_field_from_db(self, field):
        """Remove all occurrences of a specific field from all records in the MongoDB
        collection. WARNING: this operation is destructive and should be used with
        caution.
        """
        self.__collection.update_many({}, {"$unset": {field: 1}})

    @property
    def collection(self):
        """Return the MongoDB collection object."""
        return self.__collection

    def close(self):
        """Close the MongoDB client connection."""
        if hasattr(self, "_EEGDash__client"):
            self.__client.close()

    def __del__(self):
        """Ensure the connection is closed when the object is deleted."""
        self.close()
```
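Putting the class together, a typical read-only session against the public database looks roughly like this (the query values echo the docstring examples and are assumptions about what the database contains):

```python
from eegdash import EEGDash

eegdash = EEGDash()  # anonymous client for the public database

# Metadata only: returns a list of record dicts.
records = eegdash.find({"dataset": "ds002718", "subject": "012"})
print(len(records), "matching records")

# Metadata plus signals: downloads each matching recording from S3
# in parallel and returns xarray DataArrays (not cached locally).
arrays = eegdash.get({"dataset": "ds002718", "subject": "012"})

eegdash.close()
```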
```python
class EEGDashDataset(BaseConcatDataset):
    def __init__(
        self,
        query: dict | None = None,
        data_dir: str | list | None = None,
        dataset: str | list | None = None,
        description_fields: list[str] = [
            "subject",
            "session",
            "run",
            "task",
            "age",
            "gender",
            "sex",
        ],
        cache_dir: str = ".eegdash_cache",
        s3_bucket: str | None = None,
        **kwargs,
    ):
        """Create a new EEGDashDataset from a given query or local BIDS dataset
        directory and dataset name. An EEGDashDataset is a pooled collection of
        EEGDashBaseDataset instances (individual recordings) and is a subclass of
        braindecode's BaseConcatDataset.

        Parameters
        ----------
        query : dict | None
            Optionally a dictionary that specifies the query to be executed; see
            EEGDash.find() for details on the query format.
        data_dir : str | list[str] | None
            Optionally a string or a list of strings specifying one or more local
            BIDS dataset directories from which to load the EEG data files. Exactly
            one of query or data_dir must be provided.
        dataset : str | list[str] | None
            If data_dir is given, a name or list of names for the dataset(s) to be
            loaded.
        description_fields : list[str]
            A list of fields to be extracted from the dataset records
            and included in the returned data description(s). Examples are typical
            subject metadata fields such as "subject", "session", "run", "task", etc.;
            see also data_config.description_fields for the default set of fields.
        cache_dir : str
            A directory where the dataset will be cached locally.
        s3_bucket : str | None
            An optional S3 bucket URI (e.g., "s3://mybucket") to use instead of the
            default OpenNeuro bucket for loading data files.
        kwargs : dict
            Additional keyword arguments to be passed to the EEGDashBaseDataset
            constructor.

        """
        self.cache_dir = cache_dir
        self.s3_bucket = s3_bucket
        if query:
            datasets = self.find_datasets(query, description_fields, **kwargs)
        elif data_dir:
            if isinstance(data_dir, str):
                datasets = self.load_bids_dataset(
                    dataset, data_dir, description_fields, s3_bucket
                )
            else:
                assert len(data_dir) == len(dataset), (
                    "Number of datasets and their directories must match"
                )
                datasets = []
                for i, _ in enumerate(data_dir):
                    datasets.extend(
                        self.load_bids_dataset(
                            dataset[i], data_dir[i], description_fields, s3_bucket
                        )
                    )

        super().__init__(datasets)

    def find_key_in_nested_dict(self, data: Any, target_key: str) -> Any:
        """Helper to recursively search for a key in a nested dictionary structure;
        returns the value associated with the first occurrence of the key, or None if
        not found.
        """
        if isinstance(data, dict):
            if target_key in data:
                return data[target_key]
            for value in data.values():
                result = self.find_key_in_nested_dict(value, target_key)
                if result is not None:
                    return result
        return None

    def find_datasets(
        self, query: dict[str, Any], description_fields: list[str], **kwargs
    ) -> list[EEGDashBaseDataset]:
        """Helper method to find datasets in the MongoDB collection that satisfy the
        given query and return them as a list of EEGDashBaseDataset objects.

        Parameters
        ----------
        query : dict
            The query object, as in EEGDash.find().
        description_fields : list[str]
            A list of fields to be extracted from the dataset records and included in
            the returned dataset description(s).
        kwargs :
            Additional keyword arguments to be passed to the EEGDashBaseDataset
            constructor.

        Returns
        -------
        list
            A list of EEGDashBaseDataset objects that match the query.

        """
        eeg_dash_instance = EEGDash()
        try:
            datasets = []
            for record in eeg_dash_instance.find(query):
                description = {}
                for field in description_fields:
                    value = self.find_key_in_nested_dict(record, field)
                    if value is not None:
                        description[field] = value
                datasets.append(
                    EEGDashBaseDataset(
                        record,
                        self.cache_dir,
                        self.s3_bucket,
                        description=description,
                        **kwargs,
                    )
                )
            return datasets
        finally:
            eeg_dash_instance.close()

    def load_bids_dataset(
        self,
        dataset,
        data_dir,
        description_fields: list[str],
        s3_bucket: str | None = None,
        **kwargs,
    ):
        """Helper method to load a single local BIDS dataset and return it as a list
        of EEGDashBaseDatasets (one for each recording in the dataset).

        Parameters
        ----------
        dataset : str
            A name for the dataset to be loaded (e.g., "ds002718").
        data_dir : str
            The path to the local BIDS dataset directory.
        description_fields : list[str]
            A list of fields to be extracted from the dataset records
            and included in the returned dataset description(s).

        """

        def get_base_dataset_from_bids_file(
            bids_dataset: EEGBIDSDataset,
            bids_file: str,
            eeg_dash_instance: EEGDash,
            s3_bucket: str | None,
        ) -> EEGDashBaseDataset:
            """Instantiate a single EEGDashBaseDataset given a local BIDS file. Note
            that this does not actually load the data from disk, but will access the
            metadata.
            """
            record = eeg_dash_instance.load_eeg_attrs_from_bids_file(
                bids_dataset, bids_file
            )
            description = {}
            for field in description_fields:
                value = self.find_key_in_nested_dict(record, field)
                if value is not None:
                    description[field] = value
            return EEGDashBaseDataset(
                record, self.cache_dir, s3_bucket, description=description, **kwargs
            )

        bids_dataset = EEGBIDSDataset(
            data_dir=data_dir,
            dataset=dataset,
        )
        eeg_dash_instance = EEGDash()
        try:
            datasets = Parallel(n_jobs=-1, prefer="threads", verbose=1)(
                delayed(get_base_dataset_from_bids_file)(
                    bids_dataset, bids_file, eeg_dash_instance, s3_bucket
                )
                for bids_file in bids_dataset.get_files()
            )
            return datasets
        finally:
            eeg_dash_instance.close()
```
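Since EEGDashDataset subclasses braindecode's BaseConcatDataset, the result can be fed straight into braindecode/PyTorch tooling. A minimal sketch of both construction paths, with illustrative query values and an assumed local path:

```python
from eegdash import EEGDashDataset

# Path 1: query the EEGDash database; matching recordings are cached locally.
ds = EEGDashDataset(
    query={"dataset": "ds002718", "subject": "012"},
    cache_dir=".eegdash_cache",
)

# Path 2: wrap an already-downloaded local BIDS dataset instead of a query.
# ds = EEGDashDataset(data_dir="/data/ds002718", dataset="ds002718")

print(len(ds.datasets), "recordings")
print(ds.description.head())  # per-recording metadata (braindecode DataFrame)
```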