eegdash 0.1.0__py3-none-any.whl → 0.2.1.dev178237806__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of eegdash might be problematic. Click here for more details.

eegdash/main.py DELETED
@@ -1,416 +0,0 @@
1
- import json
2
- import os
3
- import tempfile
4
- from collections import defaultdict
5
- from pathlib import Path
6
-
7
- import mne
8
- import numpy as np
9
- import pymongo
10
- import s3fs
11
- import xarray as xr
12
- from dotenv import load_dotenv
13
- from joblib import Parallel, delayed
14
- from pymongo import DeleteOne, InsertOne, MongoClient, UpdateOne
15
-
16
- from braindecode.datasets import BaseConcatDataset, BaseDataset
17
-
18
- from .data_config import config as data_config
19
- from .data_utils import EEGBIDSDataset, EEGDashBaseDataset, EEGDashBaseRaw
20
-
21
-
22
class EEGDash:
    """Client for the EEGDash record database and the OpenNeuro S3 bucket.

    Record metadata lives in the ``records`` collection of the ``eegdash``
    MongoDB database; raw EEG files are read from ``AWS_BUCKET`` through an
    ``s3fs`` filesystem.
    """

    AWS_BUCKET = "s3://openneuro.org"

    def __init__(self, is_public=True):
        """Open the database connection and the S3 filesystem.

        Parameters
        ----------
        is_public : bool, default True
            If True, connect with the built-in connection string. If False,
            load a ``.env`` file and read ``DB_CONNECTION_STRING`` from the
            environment; only this mode may write records
            (see ``add_bids_dataset``).

        Raises
        ------
        ValueError
            If ``is_public`` is False and ``DB_CONNECTION_STRING`` is unset.
        """
        self.config = data_config
        if is_public:
            # NOTE(review): credentials are embedded in source; presumably a
            # read-only account intended for public use - confirm and rotate.
            DB_CONNECTION_STRING = "mongodb+srv://eegdash-user:mdzoMjQcHWTVnKDq@cluster0.vz35p.mongodb.net/?retryWrites=true&w=majority&appName=Cluster0"
        else:
            load_dotenv()
            DB_CONNECTION_STRING = os.getenv("DB_CONNECTION_STRING")
            if not DB_CONNECTION_STRING:
                # pymongo treats a None host as "localhost", so a missing
                # variable would otherwise silently target the wrong server.
                raise ValueError(
                    "DB_CONNECTION_STRING must be set (environment or .env "
                    "file) when is_public=False"
                )

        self.__client = pymongo.MongoClient(DB_CONNECTION_STRING)
        self.__db = self.__client["eegdash"]
        self.__collection = self.__db["records"]

        self.is_public = is_public
        # Files are read without AWS credentials (anon=True).
        self.filesystem = s3fs.S3FileSystem(
            anon=True, client_kwargs={"region_name": "us-east-2"}
        )

    def find(self, *args):
        """Run ``collection.find(*args)`` and return all matches as a list."""
        return list(self.__collection.find(*args))

    def exist(self, query: dict):
        """Return True if at least one record matches ``query``.

        Only ``data_name`` and ``dataset`` are accepted as query fields.

        Raises
        ------
        ValueError
            If ``query`` contains any other field.
        """
        accepted_query_fields = {"data_name", "dataset"}
        unsupported = set(query) - accepted_query_fields
        if unsupported:
            raise ValueError(f"Unsupported query fields: {sorted(unsupported)}")
        # find_one stops at the first match instead of loading every
        # matching record only to check whether the list is empty.
        return self.__collection.find_one(query) is not None

    def _validate_input(self, record: dict):
        """Check that ``record`` only has known keys with the expected types.

        Returns the record unchanged.

        Raises
        ------
        ValueError
            If ``data_name`` is missing, or a key is unknown or mistyped.
        """
        input_types = {
            "data_name": str,
            "dataset": str,
            "bidspath": str,
            "subject": str,
            "task": str,
            "session": str,
            "run": str,
            "sampling_frequency": float,
            "modality": str,
            "nchans": int,
            "ntimes": int,
            "channel_types": list,
            "channel_names": list,
        }
        if "data_name" not in record:
            raise ValueError("Missing key: data_name")
        # check if args are in the keys and has correct type
        for key, value in record.items():
            if key not in input_types:
                raise ValueError(f"Invalid input: {key}")
            if not isinstance(value, input_types[key]):
                raise ValueError(f"Invalid input: {key}")

        return record

    def get_s3path(self, record: dict) -> str:
        """Return the S3 URI of a record's raw file: ``AWS_BUCKET/bidspath``.

        ``bidspath`` is stored starting with the dataset id (see
        ``load_eeg_attrs_from_bids_file``). This assumes the bucket mirrors
        that layout - NOTE(review): confirm for non-OpenNeuro datasets.
        """
        return f"{self.AWS_BUCKET}/{record['bidspath']}"

    def load_eeg_data_from_s3(self, s3path):
        """Download an EEGLAB ``.set`` file from S3 and load it as an xarray.

        The file is staged in a named temporary file, which is removed
        whether or not the download and parsing succeed.
        """
        tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".set")
        try:
            with tmp, self.filesystem.open(s3path) as s3_file:
                tmp.write(s3_file.read())
            # The temp file is closed before MNE reopens it by name.
            return self.load_eeg_data_from_bids_file(tmp.name)
        finally:
            os.unlink(tmp.name)

    def load_eeg_data_from_bids_file(self, bids_file, eeg_attrs=None):
        """Load an EEGLAB file into a ``(channel, time)`` ``xarray.DataArray``.

        ``bids_file`` must be a file of the BIDS dataset. ``eeg_attrs`` is
        accepted for interface compatibility but currently unused.
        """
        EEG = mne.io.read_raw_eeglab(bids_file)
        eeg_data = EEG.get_data()

        fs = EEG.info["sfreq"]
        # Sample i is taken at i / fs seconds. np.linspace(0, n / fs, n) would
        # instead space samples n / (fs * (n - 1)) apart.
        time_steps = np.arange(eeg_data.shape[1]) / fs

        return xr.DataArray(
            data=eeg_data,
            dims=["channel", "time"],
            coords={"time": time_steps, "channel": EEG.ch_names},
        )

    def get_raw_extensions(self, bids_file, bids_dataset: EEGBIDSDataset):
        """List the companion raw-data files that exist next to ``bids_file``.

        Returns their paths relative to the BIDS dataset, as strings.
        Raises KeyError for an unsupported raw-file suffix.
        """
        bids_file = Path(bids_file)
        extensions = {
            ".set": [".set", ".fdt"],  # eeglab
            ".edf": [".edf"],  # european
            ".vhdr": [".eeg", ".vhdr", ".vmrk", ".dat", ".raw"],  # brainvision
            ".bdf": [".bdf"],  # biosemi
        }
        return [
            str(bids_dataset.get_relative_bidspath(bids_file.with_suffix(suffix)))
            for suffix in extensions[bids_file.suffix]
            if bids_file.with_suffix(suffix).exists()
        ]

    def load_eeg_attrs_from_bids_file(self, bids_dataset: EEGBIDSDataset, bids_file):
        """Build the database record for one file of ``bids_dataset``.

        Every field listed in ``config["attributes"]`` is present. A field
        whose extraction fails is set to None and the error is printed.

        Raises
        ------
        ValueError
            If ``bids_file`` is not one of ``bids_dataset.files``.
        """
        if bids_file not in bids_dataset.files:
            raise ValueError(f"{bids_file} not in {bids_dataset.dataset}")

        # Initialize attrs with None values for all expected fields
        attrs = {field: None for field in self.config["attributes"].keys()}

        file_name = os.path.basename(bids_file)
        dsnumber = bids_dataset.dataset
        # Keep everything from the FIRST occurrence of the dataset id onwards.
        # maxsplit=1 matters: a plain split would keep only the text between
        # the first and second occurrences (e.g. ".../ds001/ds001/sub-01").
        openneuro_path = dsnumber + bids_file.split(dsnumber, 1)[1]

        # Update with actual values where available
        try:
            participants_tsv = bids_dataset.subject_participant_tsv(bids_file)
        except Exception as e:
            print(f"Error getting participants_tsv: {str(e)}")
            participants_tsv = None

        try:
            eeg_json = bids_dataset.eeg_json(bids_file)
        except Exception as e:
            print(f"Error getting eeg_json: {str(e)}")
            eeg_json = None

        bidsdependencies = []
        for extension in self.config["bids_dependencies_files"]:
            try:
                dep_paths = bids_dataset.get_bids_metadata_files(bids_file, extension)
            except Exception:
                # Best effort: metadata files that cannot be resolved are
                # simply not listed as dependencies.
                continue
            bidsdependencies.extend(
                str(bids_dataset.get_relative_bidspath(dep)) for dep in dep_paths
            )

        bidsdependencies.extend(self.get_raw_extensions(bids_file, bids_dataset))

        # Define field extraction functions with error handling
        field_extractors = {
            "data_name": lambda: f"{bids_dataset.dataset}_{file_name}",
            "dataset": lambda: bids_dataset.dataset,
            "bidspath": lambda: openneuro_path,
            "subject": lambda: bids_dataset.get_bids_file_attribute(
                "subject", bids_file
            ),
            "task": lambda: bids_dataset.get_bids_file_attribute("task", bids_file),
            "session": lambda: bids_dataset.get_bids_file_attribute(
                "session", bids_file
            ),
            "run": lambda: bids_dataset.get_bids_file_attribute("run", bids_file),
            "modality": lambda: bids_dataset.get_bids_file_attribute(
                "modality", bids_file
            ),
            "sampling_frequency": lambda: bids_dataset.get_bids_file_attribute(
                "sfreq", bids_file
            ),
            "nchans": lambda: bids_dataset.get_bids_file_attribute("nchans", bids_file),
            "ntimes": lambda: bids_dataset.get_bids_file_attribute("ntimes", bids_file),
            "participant_tsv": lambda: participants_tsv,
            "eeg_json": lambda: eeg_json,
            "bidsdependencies": lambda: bidsdependencies,
        }

        # Dynamically populate attrs with error handling
        for field, extractor in field_extractors.items():
            try:
                attrs[field] = extractor()
            except Exception as e:
                print(f"Error extracting {field}: {str(e)}")
                attrs[field] = None

        return attrs

    def add_bids_dataset(self, dataset, data_dir, overwrite=True):
        """Insert (or, with ``overwrite``, update) records for a BIDS dataset.

        If ``overwrite`` is False and the dataset already has records,
        nothing is written. Files whose record cannot be built are reported
        and skipped. All writes are sent in one unordered ``bulk_write``.

        Raises
        ------
        ValueError
            If this client was created with ``is_public=True``.
        """
        if self.is_public:
            raise ValueError("This operation is not allowed for public users")

        if not overwrite and self.exist({"dataset": dataset}):
            print(f"Dataset {dataset} already exists in the database")
            return
        try:
            bids_dataset = EEGBIDSDataset(
                data_dir=data_dir,
                dataset=dataset,
            )
        except Exception as e:
            print(f"Error creating bids dataset {dataset}: {str(e)}")
            raise

        requests = []
        for bids_file in bids_dataset.get_files():
            try:
                data_id = f"{dataset}_{os.path.basename(bids_file)}"
                already_stored = self.exist({"data_name": data_id})
                if already_stored and not overwrite:
                    continue
                eeg_attrs = self.load_eeg_attrs_from_bids_file(bids_dataset, bids_file)
                if already_stored:
                    requests.append(self.update_request(eeg_attrs))
                else:
                    requests.append(self.add_request(eeg_attrs))
            except Exception as e:
                # Skip this file but keep processing the rest of the dataset.
                print("error adding record", bids_file, e)

        print("Number of database requests", len(requests))

        if requests:
            result = self.__collection.bulk_write(requests, ordered=False)
            print(f"Inserted: {result.inserted_count}")
            print(f"Modified: {result.modified_count}")
            print(f"Deleted: {result.deleted_count}")
            print(f"Upserted: {result.upserted_count}")
            print(f"Errors: {result.bulk_api_result.get('writeErrors', [])}")

    def get(self, query: dict):
        """Load the EEG data of every record matching ``query`` from S3.

        Parameters
        ----------
        query : dict
            MongoDB filter, e.g. ``{"dataset": "dsxxxx"}``.

        Returns
        -------
        list of xarray.DataArray
            One array per matching record; empty if nothing matches.
        """
        sessions = self.find(query)
        if not sessions:
            return []
        print(f"Found {len(sessions)} records")
        n_jobs = -1 if len(sessions) > 1 else 1
        return Parallel(n_jobs=n_jobs, prefer="threads", verbose=1)(
            delayed(self.load_eeg_data_from_s3)(self.get_s3path(session))
            for session in sessions
        )

    def add_request(self, record: dict):
        """Return a bulk-write ``InsertOne`` operation for ``record``."""
        return InsertOne(record)

    def add(self, record: dict):
        """Insert ``record``; failures are printed rather than raised."""
        try:
            self.__collection.insert_one(record)
        except ValueError as e:
            print(f"Failed to validate record: {record.get('data_name')}")
            print(e)
        except Exception as e:
            # Previously a bare except, which also swallowed KeyboardInterrupt.
            print(f"Error adding record: {record.get('data_name')}")
            print(e)

    def update_request(self, record: dict):
        """Return a bulk-write ``UpdateOne`` that ``$set``s ``record`` by data_name."""
        return UpdateOne({"data_name": record["data_name"]}, {"$set": record})

    def update(self, record: dict):
        """``$set`` the fields of ``record`` on the record with the same data_name.

        Failures are printed rather than raised.
        """
        try:
            self.__collection.update_one(
                {"data_name": record["data_name"]}, {"$set": record}
            )
        except Exception as e:
            print(f"Error updating record {record.get('data_name')}")
            print(e)

    def remove_field(self, record, field):
        """Remove ``field`` from the stored record with ``record``'s data_name."""
        self.__collection.update_one(
            {"data_name": record["data_name"]}, {"$unset": {field: 1}}
        )

    def remove_field_from_db(self, field):
        """Remove ``field`` from every record in the collection."""
        self.__collection.update_many({}, {"$unset": {field: 1}})

    @property
    def collection(self):
        """The underlying pymongo ``records`` collection."""
        return self.__collection
309
-
310
-
311
class EEGDashDataset(BaseConcatDataset):
    """A braindecode ``BaseConcatDataset`` of EEGDash recordings.

    Recordings are selected either by a database ``query`` or from one or
    more local BIDS directories (``data_dir`` paired with ``dataset``).
    """

    # Used when ``description_fields`` is not given. A tuple, so the shared
    # default can never be mutated between instances.
    _DEFAULT_DESCRIPTION_FIELDS = (
        "subject",
        "session",
        "run",
        "task",
        "age",
        "gender",
        "sex",
    )

    def __init__(
        self,
        query: dict = None,
        data_dir: str | list = None,
        dataset: str | list = None,
        description_fields: list[str] = None,
        cache_dir: str = ".eegdash_cache",
        **kwargs,
    ):
        """Build the dataset from a query or from local BIDS directories.

        Parameters
        ----------
        query : dict, optional
            MongoDB filter; takes precedence over ``data_dir``.
        data_dir : str, path-like or list, optional
            A BIDS directory, or a list of them paired element-wise with
            ``dataset``.
        dataset : str or list, optional
            Dataset id(s) matching ``data_dir``.
        description_fields : list of str, optional
            Record keys copied into each recording's description. Defaults
            to subject, session, run, task, age, gender and sex.
        cache_dir : str
            Cache directory handed to every ``EEGDashBaseDataset``.
        **kwargs
            Forwarded to every ``EEGDashBaseDataset``.

        Raises
        ------
        ValueError
            If neither ``query`` nor ``data_dir`` is given, or if lists of
            directories and dataset ids differ in length.
        """
        self.cache_dir = cache_dir
        if description_fields is None:
            description_fields = list(self._DEFAULT_DESCRIPTION_FIELDS)

        if query:
            datasets = self.find_datasets(query, description_fields, **kwargs)
        elif data_dir:
            if isinstance(data_dir, (str, os.PathLike)):
                datasets = self.load_bids_dataset(
                    dataset, os.fspath(data_dir), description_fields, **kwargs
                )
            else:
                if len(data_dir) != len(dataset):
                    raise ValueError(
                        "Number of datasets and their directories must match"
                    )
                datasets = []
                for ds_id, ds_dir in zip(dataset, data_dir):
                    datasets.extend(
                        self.load_bids_dataset(
                            ds_id, ds_dir, description_fields, **kwargs
                        )
                    )
        else:
            # Previously this fell through to an UnboundLocalError.
            raise ValueError("Either `query` or `data_dir` must be provided")

        super().__init__(datasets)

    def find_key_in_nested_dict(self, data, target_key):
        """Depth-first search of nested dicts for ``target_key``.

        Returns the first non-None value found, or None. Lists are not
        searched.
        """
        if isinstance(data, dict):
            if target_key in data:
                return data[target_key]
            for value in data.values():
                result = self.find_key_in_nested_dict(value, target_key)
                if result is not None:
                    return result
        return None

    def _build_description(self, record, description_fields):
        """Map each field of ``description_fields`` found in ``record`` to its value."""
        description = {}
        for field in description_fields:
            value = self.find_key_in_nested_dict(record, field)
            if value is not None:
                description[field] = value
        return description

    def find_datasets(self, query: dict, description_fields: list[str], **kwargs):
        """Return one ``EEGDashBaseDataset`` per database record matching ``query``."""
        client = EEGDash()
        return [
            EEGDashBaseDataset(
                record,
                self.cache_dir,
                description=self._build_description(record, description_fields),
                **kwargs,
            )
            for record in client.find(query)
        ]

    def load_bids_dataset(
        self,
        dataset,
        data_dir,
        description_fields: list[str],
        raw_format="eeglab",
        **kwargs,
    ):
        """Return one ``EEGDashBaseDataset`` per file of a local BIDS dataset.

        Records are built in parallel threads. ``raw_format`` is accepted for
        interface compatibility but currently unused.
        """
        bids_dataset = EEGBIDSDataset(
            data_dir=data_dir,
            dataset=dataset,
        )
        client = EEGDash()

        def build_one(bids_file):
            record = client.load_eeg_attrs_from_bids_file(bids_dataset, bids_file)
            return EEGDashBaseDataset(
                record,
                self.cache_dir,
                description=self._build_description(record, description_fields),
                **kwargs,
            )

        return Parallel(n_jobs=-1, prefer="threads", verbose=1)(
            delayed(build_one)(bids_file) for bids_file in bids_dataset.get_files()
        )
407
-
408
-
409
def main():
    """Fetch one subject's records from the public database and print them."""
    client = EEGDash()
    subject_query = {"dataset": "ds005511", "subject": "NDARUF236HM7"}
    matches = client.find(subject_query)
    print(matches)


if __name__ == "__main__":
    main()
@@ -1,23 +0,0 @@
1
- eegdash/__init__.py,sha256=dyNvSv7ORVDYDz0P-XBNj_SApMlOqwt8LHQqfeuPKCg,105
2
- eegdash/data_config.py,sha256=sIwj7lnZ1hCjeFs-0CXeHn93btm9fX7mwgVTZVeVh-w,763
3
- eegdash/data_utils.py,sha256=LqAJygSpPpYEIerAnWHuHP0OMjd7jQtzXIodbvb0568,19436
4
- eegdash/main.py,sha256=CFI-Bro_oru5iRJdNQZ8IqeRPhrZKXj8wKoMdcrhFt8,14865
5
- eegdash/features/__init__.py,sha256=Ijhc-bLwysyF_HTmdJwbYoTHbxj2wxArs1xSUzhm7Hc,604
6
- eegdash/features/datasets.py,sha256=JB-VTfXTwfbxpgF9wq34gKK69YNCZPQwsnaKEXQisWk,17180
7
- eegdash/features/decorators.py,sha256=iVsbdQXGoLi-V6M9BgP6P8i_UzUtIAWQlf8Qq_LdRqY,1247
8
- eegdash/features/extractors.py,sha256=bITM4DXbW1Dq8Nm8hS3OrSGfRFV6-IwzkTzjiy_yg9k,6816
9
- eegdash/features/serialization.py,sha256=ceGcEvKCg4OsWyLpdAyJsvU1-6UXcvVx2q6nq58vt8Y,2873
10
- eegdash/features/utils.py,sha256=jjVNVLFSXFj3j7NWgEbUlt5faTrWKLLQY9ZYy0xLp_M,3782
11
- eegdash/features/feature_bank/__init__.py,sha256=BKrM3aaggXrfey1yEjEBYaxOV5e3UK-o8oGeB30epOg,149
12
- eegdash/features/feature_bank/complexity.py,sha256=WkLin-f1WTPUtcpkLDObY8nQYRsvpa08Xy9ly1k0hik,3017
13
- eegdash/features/feature_bank/connectivity.py,sha256=bQ6KlxWm5GNpCS9ypLqBUr2L171Yq7wpBQT2tRQKTZ4,2159
14
- eegdash/features/feature_bank/csp.py,sha256=O-kUijM47cOH7yfe7sYL9wT41w1dGaq6sOieh-h82pw,3300
15
- eegdash/features/feature_bank/dimensionality.py,sha256=e8rKpAT_xtZRsBDuVbznFx_daWdQj89Z3Zkt61Hs5qk,3734
16
- eegdash/features/feature_bank/signal.py,sha256=4jgIXRVS274puKfOnDNnqLoBP_yXRyP38iMnXRvobYo,2437
17
- eegdash/features/feature_bank/spectral.py,sha256=bNB7skusePs1gX7NOU6yRlw_Gr4UOCkO_ylkCgybzug,3319
18
- eegdash/features/feature_bank/utils.py,sha256=DGh-Q7-XFIittP7iBBxvsJaZrlVvuY5mw-G7q6C-PCI,1237
19
- eegdash-0.1.0.dist-info/licenses/LICENSE,sha256=Xafu48R-h_kyaNj2tuhfgdEv9_ovciktjUEgRRwMZ6w,812
20
- eegdash-0.1.0.dist-info/METADATA,sha256=RixWQ9dqP1IQzz_HCAZL2Sp-at190rx4ocpvy2DVaio,8551
21
- eegdash-0.1.0.dist-info/WHEEL,sha256=zaaOINJESkSfm_4HQVc5ssNzHCPXhJm0kEUakpsEHaU,91
22
- eegdash-0.1.0.dist-info/top_level.txt,sha256=zavO69HQ6MyZM0aQMR2zUS6TAFc7bnN5GEpDpOpFZzU,8
23
- eegdash-0.1.0.dist-info/RECORD,,