eegdash 0.0.8__py3-none-any.whl → 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release. This version of eegdash might be problematic.
- eegdash/__init__.py +4 -1
- eegdash/data_config.py +28 -0
- eegdash/data_utils.py +193 -148
- eegdash/features/__init__.py +25 -0
- eegdash/features/datasets.py +456 -0
- eegdash/features/decorators.py +43 -0
- eegdash/features/extractors.py +210 -0
- eegdash/features/feature_bank/__init__.py +6 -0
- eegdash/features/feature_bank/complexity.py +96 -0
- eegdash/features/feature_bank/connectivity.py +59 -0
- eegdash/features/feature_bank/csp.py +101 -0
- eegdash/features/feature_bank/dimensionality.py +107 -0
- eegdash/features/feature_bank/signal.py +103 -0
- eegdash/features/feature_bank/spectral.py +116 -0
- eegdash/features/feature_bank/utils.py +48 -0
- eegdash/features/serialization.py +87 -0
- eegdash/features/utils.py +116 -0
- eegdash/main.py +250 -145
- {eegdash-0.0.8.dist-info → eegdash-0.1.0.dist-info}/METADATA +26 -56
- eegdash-0.1.0.dist-info/RECORD +23 -0
- {eegdash-0.0.8.dist-info → eegdash-0.1.0.dist-info}/WHEEL +1 -1
- eegdash-0.0.8.dist-info/RECORD +0 -8
- {eegdash-0.0.8.dist-info → eegdash-0.1.0.dist-info}/licenses/LICENSE +0 -0
- {eegdash-0.0.8.dist-info → eegdash-0.1.0.dist-info}/top_level.txt +0 -0
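The headline addition in 0.1.0 is the new eegdash.features subpackage (extractors, decorators, serialization, and a feature_bank of signal, spectral, complexity, connectivity, CSP, and dimensionality features). This diff view does not expand those files, so the sketch below assumes only the module paths listed above and enumerates the installed subpackage at runtime rather than guessing at its API:

# Hedged sketch: list what the new 0.1.0 features subpackage ships. Only the
# module paths from the file list above are assumed; the subpackage's public
# API is not shown in this diff, so nothing beyond import/enumeration is done.
import pkgutil

import eegdash.features  # new in 0.1.0

for mod in pkgutil.iter_modules(eegdash.features.__path__):
    # expected per the file list: datasets, decorators, extractors,
    # feature_bank, serialization, utils
    print(mod.name)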
eegdash/main.py
CHANGED
@@ -1,69 +1,81 @@
-
-import pymongo
-from dotenv import load_dotenv
+import json
 import os
-from pathlib import Path
-import s3fs
-from joblib import Parallel, delayed
 import tempfile
+from collections import defaultdict
+from pathlib import Path
+
 import mne
 import numpy as np
+import pymongo
+import s3fs
 import xarray as xr
-from
-from
-from
-
+from dotenv import load_dotenv
+from joblib import Parallel, delayed
+from pymongo import DeleteOne, InsertOne, MongoClient, UpdateOne
+
+from braindecode.datasets import BaseConcatDataset, BaseDataset
+
+from .data_config import config as data_config
+from .data_utils import EEGBIDSDataset, EEGDashBaseDataset, EEGDashBaseRaw
+
 
 class EEGDash:
-    AWS_BUCKET =
-
-
+    AWS_BUCKET = "s3://openneuro.org"
+
+    def __init__(self, is_public=True):
+        # Load config file
+        # config_path = Path(__file__).parent / 'config.json'
+        # with open(config_path, 'r') as f:
+        #     self.config = json.load(f)
+
+        self.config = data_config
         if is_public:
-            DB_CONNECTION_STRING="mongodb+srv://eegdash-user:mdzoMjQcHWTVnKDq@cluster0.vz35p.mongodb.net/?retryWrites=true&w=majority&appName=Cluster0"
+            DB_CONNECTION_STRING = "mongodb+srv://eegdash-user:mdzoMjQcHWTVnKDq@cluster0.vz35p.mongodb.net/?retryWrites=true&w=majority&appName=Cluster0"
         else:
            load_dotenv()
-            DB_CONNECTION_STRING = os.getenv(
+            DB_CONNECTION_STRING = os.getenv("DB_CONNECTION_STRING")
 
         self.__client = pymongo.MongoClient(DB_CONNECTION_STRING)
-        self.__db = self.__client[
-        self.__collection = self.__db[
+        self.__db = self.__client["eegdash"]
+        self.__collection = self.__db["records"]
 
         self.is_public = is_public
-        self.filesystem = s3fs.S3FileSystem(
-
+        self.filesystem = s3fs.S3FileSystem(
+            anon=True, client_kwargs={"region_name": "us-east-2"}
+        )
+
     def find(self, *args):
        results = self.__collection.find(*args)
-
+
         # convert to list using get_item on each element
         return [result for result in results]
 
-    def exist(self,
-
-
-        }
+    def exist(self, query: dict):
+        accepted_query_fields = ["data_name", "dataset"]
+        assert all(field in accepted_query_fields for field in query.keys())
        sessions = self.find(query)
        return len(sessions) > 0
 
-    def _validate_input(self, record:dict):
+    def _validate_input(self, record: dict):
         input_types = {
-
-
-
-
-
-
-
-
-
-
-
-
-
+            "data_name": str,
+            "dataset": str,
+            "bidspath": str,
+            "subject": str,
+            "task": str,
+            "session": str,
+            "run": str,
+            "sampling_frequency": float,
+            "modality": str,
+            "nchans": int,
+            "ntimes": int,
+            "channel_types": list,
+            "channel_names": list,
         }
-        if
+        if "data_name" not in record:
             raise ValueError("Missing key: data_name")
         # check if args are in the keys and has correct type
-        for key,value in record.items():
+        for key, value in record.items():
             if key not in input_types:
                 raise ValueError(f"Invalid input: {key}")
             if not isinstance(value, input_types[key]):
@@ -72,7 +84,7 @@ class EEGDash:
         return record
 
     def load_eeg_data_from_s3(self, s3path):
-        with tempfile.NamedTemporaryFile(delete=False, suffix=
+        with tempfile.NamedTemporaryFile(delete=False, suffix=".set") as tmp:
             with self.filesystem.open(s3path) as s3_file:
                 tmp.write(s3_file.read())
             tmp_path = tmp.name
@@ -80,97 +92,157 @@ class EEGDash:
         os.unlink(tmp_path)
         return eeg_data
 
-    def load_eeg_data_from_bids_file(self,
-
+    def load_eeg_data_from_bids_file(self, bids_file, eeg_attrs=None):
+        """
        bids_file must be a file of the bids_dataset
-
+        """
        EEG = mne.io.read_raw_eeglab(bids_file)
        eeg_data = EEG.get_data()
-
-        fs = EEG.info[
+
+        fs = EEG.info["sfreq"]
        max_time = eeg_data.shape[1] / fs
-        time_steps = np.linspace(0, max_time, eeg_data.shape[1]).squeeze()
+        time_steps = np.linspace(0, max_time, eeg_data.shape[1]).squeeze()  # in seconds
 
        channel_names = EEG.ch_names
 
        eeg_xarray = xr.DataArray(
            data=eeg_data,
-            dims=[
-            coords={
-                'time': time_steps,
-                'channel': channel_names
-            },
+            dims=["channel", "time"],
+            coords={"time": time_steps, "channel": channel_names},
            # attrs=attrs
        )
        return eeg_xarray
 
-    def
-
+    def get_raw_extensions(self, bids_file, bids_dataset: EEGBIDSDataset):
+        bids_file = Path(bids_file)
+        extensions = {
+            ".set": [".set", ".fdt"],  # eeglab
+            ".edf": [".edf"],  # european
+            ".vhdr": [".eeg", ".vhdr", ".vmrk", ".dat", ".raw"],  # brainvision
+            ".bdf": [".bdf"],  # biosemi
+        }
+        return [
+            str(bids_dataset.get_relative_bidspath(bids_file.with_suffix(suffix)))
+            for suffix in extensions[bids_file.suffix]
+            if bids_file.with_suffix(suffix).exists()
+        ]
+
+    def load_eeg_attrs_from_bids_file(self, bids_dataset: EEGBIDSDataset, bids_file):
+        """
        bids_file must be a file of the bids_dataset
-
+        """
        if bids_file not in bids_dataset.files:
-            raise ValueError(f
+            raise ValueError(f"{bids_file} not in {bids_dataset.dataset}")
+
+        # Initialize attrs with None values for all expected fields
+        attrs = {field: None for field in self.config["attributes"].keys()}
+
        f = os.path.basename(bids_file)
        dsnumber = bids_dataset.dataset
        # extract openneuro path by finding the first occurrence of the dataset name in the filename and remove the path before that
        openneuro_path = dsnumber + bids_file.split(dsnumber)[1]
 
-
+        # Update with actual values where available
+        try:
+            participants_tsv = bids_dataset.subject_participant_tsv(bids_file)
+        except Exception as e:
+            print(f"Error getting participants_tsv: {str(e)}")
+            participants_tsv = None
+
+        try:
+            eeg_json = bids_dataset.eeg_json(bids_file)
+        except Exception as e:
+            print(f"Error getting eeg_json: {str(e)}")
+            eeg_json = None
+
+        bids_dependencies_files = self.config["bids_dependencies_files"]
        bidsdependencies = []
        for extension in bids_dependencies_files:
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+            try:
+                dep_path = bids_dataset.get_bids_metadata_files(bids_file, extension)
+                dep_path = [
+                    str(bids_dataset.get_relative_bidspath(dep)) for dep in dep_path
+                ]
+                bidsdependencies.extend(dep_path)
+            except Exception as e:
+                pass
+
+        bidsdependencies.extend(self.get_raw_extensions(bids_file, bids_dataset))
+
+        # Define field extraction functions with error handling
+        field_extractors = {
+            "data_name": lambda: f"{bids_dataset.dataset}_{f}",
+            "dataset": lambda: bids_dataset.dataset,
+            "bidspath": lambda: openneuro_path,
+            "subject": lambda: bids_dataset.get_bids_file_attribute(
+                "subject", bids_file
+            ),
+            "task": lambda: bids_dataset.get_bids_file_attribute("task", bids_file),
+            "session": lambda: bids_dataset.get_bids_file_attribute(
+                "session", bids_file
+            ),
+            "run": lambda: bids_dataset.get_bids_file_attribute("run", bids_file),
+            "modality": lambda: bids_dataset.get_bids_file_attribute(
+                "modality", bids_file
+            ),
+            "sampling_frequency": lambda: bids_dataset.get_bids_file_attribute(
+                "sfreq", bids_file
+            ),
+            "nchans": lambda: bids_dataset.get_bids_file_attribute("nchans", bids_file),
+            "ntimes": lambda: bids_dataset.get_bids_file_attribute("ntimes", bids_file),
+            "participant_tsv": lambda: participants_tsv,
+            "eeg_json": lambda: eeg_json,
+            "bidsdependencies": lambda: bidsdependencies,
        }
 
+        # Dynamically populate attrs with error handling
+        for field, extractor in field_extractors.items():
+            try:
+                attrs[field] = extractor()
+            except Exception as e:
+                print(f"Error extracting {field}: {str(e)}")
+                attrs[field] = None
+
        return attrs
 
-    def add_bids_dataset(self, dataset, data_dir,
-
+    def add_bids_dataset(self, dataset, data_dir, overwrite=True):
+        """
        Create new records for the dataset in the MongoDB database if not found
-
+        """
        if self.is_public:
-            raise ValueError(
+            raise ValueError("This operation is not allowed for public users")
 
-
-
-
-
-
+        if not overwrite and self.exist({"dataset": dataset}):
+            print(f"Dataset {dataset} already exists in the database")
+            return
+        try:
+            bids_dataset = EEGBIDSDataset(
+                data_dir=data_dir,
+                dataset=dataset,
+            )
+        except Exception as e:
+            print(f"Error creating bids dataset {dataset}: {str(e)}")
+            raise e
        requests = []
        for bids_file in bids_dataset.get_files():
            try:
                data_id = f"{dataset}_{os.path.basename(bids_file)}"
 
-                if self.exist(data_name
+                if self.exist({"data_name": data_id}):
                    if overwrite:
-                        eeg_attrs = self.load_eeg_attrs_from_bids_file(
-
+                        eeg_attrs = self.load_eeg_attrs_from_bids_file(
+                            bids_dataset, bids_file
+                        )
+                        requests.append(self.update_request(eeg_attrs))
                else:
-                    eeg_attrs = self.load_eeg_attrs_from_bids_file(
+                    eeg_attrs = self.load_eeg_attrs_from_bids_file(
+                        bids_dataset, bids_file
+                    )
                    requests.append(self.add_request(eeg_attrs))
            except:
-                print(
+                print("error adding record", bids_file)
 
-        print(
+        print("Number of database requests", len(requests))
 
        if requests:
            result = self.__collection.bulk_write(requests, ordered=False)
@@ -180,25 +252,28 @@ class EEGDash:
            print(f"Upserted: {result.upserted_count}")
            print(f"Errors: {result.bulk_api_result.get('writeErrors', [])}")
 
-    def get(self, query:dict):
-
+    def get(self, query: dict):
+        """
        query: {
            'dataset': 'dsxxxx',
 
-        }
+        }"""
        sessions = self.find(query)
        results = []
        if sessions:
-            print(f
-            results = Parallel(
-
+            print(f"Found {len(sessions)} records")
+            results = Parallel(
+                n_jobs=-1 if len(sessions) > 1 else 1, prefer="threads", verbose=1
+            )(
+                delayed(self.load_eeg_data_from_s3)(self.get_s3path(session))
+                for session in sessions
            )
        return results
 
-    def add_request(self, record:dict):
+    def add_request(self, record: dict):
        return InsertOne(record)
 
-    def add(self, record:dict):
+    def add(self, record: dict):
        try:
            # input_record = self._validate_input(record)
            self.__collection.insert_one(record)
@@ -206,48 +281,72 @@ class EEGDash:
        except ValueError as e:
            print(f"Failed to validate record: {record['data_name']}")
            print(e)
-        except:
+        except:
            print(f"Error adding record: {record['data_name']}")
 
-    def update_request(self, record:dict):
-        return UpdateOne({
+    def update_request(self, record: dict):
+        return UpdateOne({"data_name": record["data_name"]}, {"$set": record})
 
-    def update(self, record:dict):
+    def update(self, record: dict):
        try:
-            self.__collection.update_one(
-
-
+            self.__collection.update_one(
+                {"data_name": record["data_name"]}, {"$set": record}
+            )
+        except:  # silent failure
+            print(f"Error updating record {record['data_name']}")
 
    def remove_field(self, record, field):
-        self.__collection.update_one(
-
+        self.__collection.update_one(
+            {"data_name": record["data_name"]}, {"$unset": {field: 1}}
+        )
+
    def remove_field_from_db(self, field):
-        self.__collection.update_many({}, {
-
+        self.__collection.update_many({}, {"$unset": {field: 1}})
+
+    @property
+    def collection(self):
+        return self.__collection
+
 
 class EEGDashDataset(BaseConcatDataset):
-    CACHE_DIR = '.eegdash_cache'
+    # CACHE_DIR = '.eegdash_cache'
    def __init__(
        self,
-        query:dict=None,
-        data_dir:str | list =None,
-        dataset:str | list =None,
-        description_fields: list[str]=[
-
+        query: dict = None,
+        data_dir: str | list = None,
+        dataset: str | list = None,
+        description_fields: list[str] = [
+            "subject",
+            "session",
+            "run",
+            "task",
+            "age",
+            "gender",
+            "sex",
+        ],
+        cache_dir: str = ".eegdash_cache",
+        **kwargs,
    ):
+        self.cache_dir = cache_dir
        if query:
            datasets = self.find_datasets(query, description_fields, **kwargs)
        elif data_dir:
            if type(data_dir) == str:
                datasets = self.load_bids_dataset(dataset, data_dir, description_fields)
            else:
-                assert len(data_dir) == len(dataset),
+                assert len(data_dir) == len(dataset), (
+                    "Number of datasets and their directories must match"
+                )
                datasets = []
                for i in range(len(data_dir)):
-                    datasets.extend(
+                    datasets.extend(
+                        self.load_bids_dataset(
+                            dataset[i], data_dir[i], description_fields
+                        )
+                    )
        # convert to list using get_item on each element
        super().__init__(datasets)
-
+
    def find_key_in_nested_dict(self, data, target_key):
        if isinstance(data, dict):
            if target_key in data:
@@ -258,54 +357,60 @@ class EEGDashDataset(BaseConcatDataset):
                return result
        return None
 
-    def find_datasets(self, query:dict, description_fields:list[str], **kwargs):
+    def find_datasets(self, query: dict, description_fields: list[str], **kwargs):
        eegdashObj = EEGDash()
        datasets = []
        for record in eegdashObj.find(query):
            description = {}
            for field in description_fields:
                value = self.find_key_in_nested_dict(record, field)
-                if value:
+                if value is not None:
                    description[field] = value
-            datasets.append(
+            datasets.append(
+                EEGDashBaseDataset(
+                    record, self.cache_dir, description=description, **kwargs
+                )
+            )
        return datasets
 
-    def load_bids_dataset(
-
-
+    def load_bids_dataset(
+        self,
+        dataset,
+        data_dir,
+        description_fields: list[str],
+        raw_format="eeglab",
+        **kwargs,
+    ):
+        """ """
+
        def get_base_dataset_from_bids_file(bids_dataset, bids_file):
            record = eegdashObj.load_eeg_attrs_from_bids_file(bids_dataset, bids_file)
            description = {}
            for field in description_fields:
                value = self.find_key_in_nested_dict(record, field)
-                if value:
+                if value is not None:
                    description[field] = value
-            return EEGDashBaseDataset(
+            return EEGDashBaseDataset(
+                record, self.cache_dir, description=description, **kwargs
+            )
 
-        bids_dataset =
+        bids_dataset = EEGBIDSDataset(
            data_dir=data_dir,
            dataset=dataset,
-            raw_format=raw_format,
        )
        eegdashObj = EEGDash()
        datasets = Parallel(n_jobs=-1, prefer="threads", verbose=1)(
-
-        )
-
-        # for bids_file in bids_dataset.get_files():
-        #     record = eegdashObj.load_eeg_attrs_from_bids_file(bids_dataset, bids_file)
-        #     description = {}
-        #     for field in description_fields:
-        #         value = self.find_key_in_nested_dict(record, field)
-        #         if value:
-        #             description[field] = value
-        #     datasets.append(EEGDashBaseDataset(record, self.CACHE_DIR, description=description, **kwargs))
+            delayed(get_base_dataset_from_bids_file)(bids_dataset, bids_file)
+            for bids_file in bids_dataset.get_files()
+        )
        return datasets
 
+
 def main():
    eegdash = EEGDash()
-    record = eegdash.find({
+    record = eegdash.find({"dataset": "ds005511", "subject": "NDARUF236HM7"})
    print(record)
 
-
-
+
+if __name__ == "__main__":
+    main()