eegdash 0.0.9__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of eegdash might be problematic. Click here for more details.

eegdash/main.py DELETED
@@ -1,359 +0,0 @@
1
- import pymongo
2
- from dotenv import load_dotenv
3
- import os
4
- from pathlib import Path
5
- import s3fs
6
- from joblib import Parallel, delayed
7
- import json
8
- import tempfile
9
- import mne
10
- import numpy as np
11
- import xarray as xr
12
- from .data_utils import EEGBIDSDataset, EEGDashBaseRaw, EEGDashBaseDataset
13
- from .data_config import config as data_config
14
- from braindecode.datasets import BaseDataset, BaseConcatDataset
15
- from collections import defaultdict
16
- from pymongo import MongoClient, InsertOne, UpdateOne, DeleteOne
17
-
18
class EEGDash:
    """Client for the EEGDash MongoDB metadata store and OpenNeuro S3 data.

    Wraps the ``eegdash.records`` collection with query/insert/update
    helpers, and provides utilities to load EEG recordings (EEGLAB
    ``.set`` files) either from the public OpenNeuro S3 bucket or from a
    local BIDS dataset directory.
    """

    AWS_BUCKET = 's3://openneuro.org'

    def __init__(self, is_public=True):
        """Connect to the metadata database and the OpenNeuro S3 bucket.

        Parameters
        ----------
        is_public : bool
            When True, use the built-in read-only connection string.
            When False, load ``DB_CONNECTION_STRING`` from the
            environment (via ``.env``) for privileged write access.
        """
        self.config = data_config
        if is_public:
            # NOTE(review): credentials embedded in source — presumably a
            # read-only public account, but consider moving this to
            # configuration so it can be rotated without a release.
            DB_CONNECTION_STRING = "mongodb+srv://eegdash-user:mdzoMjQcHWTVnKDq@cluster0.vz35p.mongodb.net/?retryWrites=true&w=majority&appName=Cluster0"
        else:
            load_dotenv()
            DB_CONNECTION_STRING = os.getenv('DB_CONNECTION_STRING')

        self.__client = pymongo.MongoClient(DB_CONNECTION_STRING)
        self.__db = self.__client['eegdash']
        self.__collection = self.__db['records']

        self.is_public = is_public
        # Anonymous access suffices: the OpenNeuro bucket is public.
        self.filesystem = s3fs.S3FileSystem(anon=True, client_kwargs={'region_name': 'us-east-2'})

    def find(self, *args):
        """Run a MongoDB ``find`` with *args* and return the results as a list."""
        return list(self.__collection.find(*args))

    def exist(self, query: dict):
        """Return True if at least one record matches *query*.

        Only 'data_name' and 'dataset' are accepted as query fields.

        Raises
        ------
        ValueError
            If *query* contains any other field.
        """
        accepted_query_fields = ['data_name', 'dataset']
        # Raise instead of assert so validation survives ``python -O``.
        if not all(field in accepted_query_fields for field in query.keys()):
            raise ValueError(f'Query fields must be among {accepted_query_fields}')
        sessions = self.find(query)
        return len(sessions) > 0

    def _validate_input(self, record: dict):
        """Validate that *record* only contains known keys of the right type.

        Raises
        ------
        ValueError
            On a missing 'data_name' key, an unknown key, or a value of
            the wrong type.
        """
        input_types = {
            'data_name': str,
            'dataset': str,
            'bidspath': str,
            'subject': str,
            'task': str,
            'session': str,
            'run': str,
            'sampling_frequency': float,
            'modality': str,
            'nchans': int,
            'ntimes': int,
            'channel_types': list,
            'channel_names': list,
        }
        if 'data_name' not in record:
            raise ValueError("Missing key: data_name")
        # Check that every provided key is known and correctly typed.
        for key, value in record.items():
            if key not in input_types:
                raise ValueError(f"Invalid input: {key}")
            if not isinstance(value, input_types[key]):
                raise ValueError(f"Invalid input: {key}")

        return record

    def load_eeg_data_from_s3(self, s3path):
        """Download an EEGLAB ``.set`` file from S3 and load it as an xarray."""
        with tempfile.NamedTemporaryFile(delete=False, suffix='.set') as tmp:
            with self.filesystem.open(s3path) as s3_file:
                tmp.write(s3_file.read())
            tmp_path = tmp.name
        try:
            return self.load_eeg_data_from_bids_file(tmp_path)
        finally:
            # Always remove the temp file, even if loading raises.
            os.unlink(tmp_path)

    def load_eeg_data_from_bids_file(self, bids_file, eeg_attrs=None):
        """Load an EEGLAB file into an ``xarray.DataArray``.

        *bids_file* must be a file of the bids_dataset. The result has
        dims ``(channel, time)`` with time in seconds.
        """
        EEG = mne.io.read_raw_eeglab(bids_file)
        eeg_data = EEG.get_data()

        fs = EEG.info['sfreq']
        max_time = eeg_data.shape[1] / fs
        time_steps = np.linspace(0, max_time, eeg_data.shape[1]).squeeze()  # in seconds

        channel_names = EEG.ch_names

        eeg_xarray = xr.DataArray(
            data=eeg_data,
            dims=['channel', 'time'],
            coords={
                'time': time_steps,
                'channel': channel_names
            },
        )
        return eeg_xarray

    def get_raw_extensions(self, bids_file, bids_dataset: EEGBIDSDataset):
        """Return relative BIDS paths of the raw-data companion files.

        Each raw format has a known set of sibling extensions (e.g. the
        EEGLAB ``.set``/``.fdt`` pair); only files that exist are returned.
        """
        bids_file = Path(bids_file)
        extensions = {
            '.set': ['.set', '.fdt'],  # eeglab
            '.edf': ['.edf'],  # european
            '.vhdr': ['.eeg', '.vhdr', '.vmrk', '.dat', '.raw'],  # brainvision
            '.bdf': ['.bdf'],  # biosemi
        }
        return [
            str(bids_dataset.get_relative_bidspath(bids_file.with_suffix(suffix)))
            for suffix in extensions[bids_file.suffix]
            if bids_file.with_suffix(suffix).exists()
        ]

    def load_eeg_attrs_from_bids_file(self, bids_dataset: EEGBIDSDataset, bids_file):
        """Build the metadata record (attrs dict) for one BIDS file.

        *bids_file* must be a file of *bids_dataset*. Fields that cannot
        be extracted are left as None (errors are printed, not raised).
        """
        if bids_file not in bids_dataset.files:
            raise ValueError(f'{bids_file} not in {bids_dataset.dataset}')

        # Initialize attrs with None values for all expected fields.
        attrs = {field: None for field in self.config['attributes'].keys()}

        f = os.path.basename(bids_file)
        dsnumber = bids_dataset.dataset
        # Extract the OpenNeuro path: everything from the first occurrence
        # of the dataset name onward.
        openneuro_path = dsnumber + bids_file.split(dsnumber)[1]

        # Update with actual values where available.
        try:
            participants_tsv = bids_dataset.subject_participant_tsv(bids_file)
        except Exception as e:
            print(f"Error getting participants_tsv: {str(e)}")
            participants_tsv = None

        try:
            eeg_json = bids_dataset.eeg_json(bids_file)
        except Exception as e:
            print(f"Error getting eeg_json: {str(e)}")
            eeg_json = None

        bids_dependencies_files = self.config['bids_dependencies_files']
        bidsdependencies = []
        for extension in bids_dependencies_files:
            try:
                dep_path = bids_dataset.get_bids_metadata_files(bids_file, extension)
                dep_path = [str(bids_dataset.get_relative_bidspath(dep)) for dep in dep_path]
                bidsdependencies.extend(dep_path)
            except Exception:
                # Missing metadata files are expected for some datasets;
                # skip them silently.
                pass

        bidsdependencies.extend(self.get_raw_extensions(bids_file, bids_dataset))

        # Define field extraction functions with per-field error handling.
        field_extractors = {
            'data_name': lambda: f'{bids_dataset.dataset}_{f}',
            'dataset': lambda: bids_dataset.dataset,
            'bidspath': lambda: openneuro_path,
            'subject': lambda: bids_dataset.get_bids_file_attribute('subject', bids_file),
            'task': lambda: bids_dataset.get_bids_file_attribute('task', bids_file),
            'session': lambda: bids_dataset.get_bids_file_attribute('session', bids_file),
            'run': lambda: bids_dataset.get_bids_file_attribute('run', bids_file),
            'modality': lambda: bids_dataset.get_bids_file_attribute('modality', bids_file),
            'sampling_frequency': lambda: bids_dataset.get_bids_file_attribute('sfreq', bids_file),
            'nchans': lambda: bids_dataset.get_bids_file_attribute('nchans', bids_file),
            'ntimes': lambda: bids_dataset.get_bids_file_attribute('ntimes', bids_file),
            'participant_tsv': lambda: participants_tsv,
            'eeg_json': lambda: eeg_json,
            'bidsdependencies': lambda: bidsdependencies,
        }

        # Dynamically populate attrs; a failed extractor leaves None.
        for field, extractor in field_extractors.items():
            try:
                attrs[field] = extractor()
            except Exception as e:
                print(f"Error extracting {field}: {str(e)}")
                attrs[field] = None

        return attrs

    def add_bids_dataset(self, dataset, data_dir, overwrite=True):
        """Create (or update) records for a local BIDS dataset in MongoDB.

        Parameters
        ----------
        dataset : str
            Dataset identifier (e.g. 'ds005511').
        data_dir : str
            Path to the local BIDS dataset directory.
        overwrite : bool
            When True, existing records are updated; when False, a
            dataset already present in the database is skipped entirely.

        Raises
        ------
        ValueError
            If this client was opened with ``is_public=True``.
        """
        if self.is_public:
            raise ValueError('This operation is not allowed for public users')

        if not overwrite and self.exist({'dataset': dataset}):
            print(f'Dataset {dataset} already exists in the database')
            return
        try:
            bids_dataset = EEGBIDSDataset(
                data_dir=data_dir,
                dataset=dataset,
            )
        except Exception as e:
            print(f'Error creating bids dataset {dataset}: {str(e)}')
            raise e
        requests = []
        for bids_file in bids_dataset.get_files():
            try:
                data_id = f"{dataset}_{os.path.basename(bids_file)}"

                if self.exist({'data_name': data_id}):
                    if overwrite:
                        eeg_attrs = self.load_eeg_attrs_from_bids_file(bids_dataset, bids_file)
                        requests.append(self.update_request(eeg_attrs))
                else:
                    eeg_attrs = self.load_eeg_attrs_from_bids_file(bids_dataset, bids_file)
                    requests.append(self.add_request(eeg_attrs))
            except Exception:
                # Best-effort: narrow from a bare ``except:`` so that
                # KeyboardInterrupt/SystemExit still propagate.
                print('error adding record', bids_file)

        print('Number of database requests', len(requests))

        if requests:
            result = self.__collection.bulk_write(requests, ordered=False)
            print(f"Inserted: {result.inserted_count}")
            print(f"Modified: {result.modified_count}")
            print(f"Deleted: {result.deleted_count}")
            print(f"Upserted: {result.upserted_count}")
            print(f"Errors: {result.bulk_api_result.get('writeErrors', [])}")

    def get(self, query: dict):
        """Find matching records and load their EEG data from S3 in parallel.

        query: {
            'dataset': 'dsxxxx',
        }

        Returns a list of ``xarray.DataArray`` objects (empty when no
        record matches).
        """
        sessions = self.find(query)
        results = []
        if sessions:
            print(f'Found {len(sessions)} records')
            results = Parallel(n_jobs=-1 if len(sessions) > 1 else 1, prefer="threads", verbose=1)(
                delayed(self.load_eeg_data_from_s3)(self.get_s3path(session)) for session in sessions
            )
        return results

    def add_request(self, record: dict):
        """Build an ``InsertOne`` request for ``bulk_write``."""
        return InsertOne(record)

    def add(self, record: dict):
        """Insert a single record; failures are printed, not raised."""
        try:
            self.__collection.insert_one(record)
        except ValueError as e:
            print(f"Failed to validate record: {record['data_name']}")
            print(e)
        except Exception:
            # Narrowed from a bare ``except:``; still a deliberate
            # best-effort insert that never raises to the caller.
            print(f"Error adding record: {record['data_name']}")

    def update_request(self, record: dict):
        """Build an ``UpdateOne`` ($set) request keyed on 'data_name'."""
        return UpdateOne({'data_name': record['data_name']}, {'$set': record})

    def update(self, record: dict):
        """Update a single record by 'data_name'; failures are printed, not raised."""
        try:
            self.__collection.update_one({'data_name': record['data_name']}, {'$set': record})
        except Exception:  # deliberate silent failure, but no longer a bare except
            print(f'Error updating record {record["data_name"]}')

    def remove_field(self, record, field):
        """Unset *field* on the record matching ``record['data_name']``."""
        self.__collection.update_one({'data_name': record['data_name']}, {'$unset': {field: 1}})

    def remove_field_from_db(self, field):
        """Unset *field* on every record in the collection."""
        self.__collection.update_many({}, {'$unset': {field: 1}})

    @property
    def collection(self):
        """The underlying pymongo ``records`` collection."""
        return self.__collection
282
class EEGDashDataset(BaseConcatDataset):
    """A braindecode ``BaseConcatDataset`` built from EEGDash records.

    Datasets can come either from a database *query* or from one or more
    local BIDS dataset directories (*data_dir*/*dataset*).
    """

    def __init__(
        self,
        query: dict = None,
        data_dir: str | list = None,
        dataset: str | list = None,
        description_fields: list[str] = None,
        cache_dir: str = '.eegdash_cache',
        **kwargs
    ):
        """Build the concat dataset.

        Parameters
        ----------
        query : dict, optional
            MongoDB query; when given, records are fetched from the
            database and *data_dir*/*dataset* are ignored.
        data_dir : str or list, optional
            One local BIDS directory, or a list of them (paired with
            *dataset*).
        dataset : str or list, optional
            Dataset identifier(s) matching *data_dir*.
        description_fields : list of str, optional
            Record fields copied into each dataset's description.
            Defaults to subject/session/run/task/age/gender/sex.
        cache_dir : str
            Local cache directory for downloaded data.
        **kwargs
            Forwarded to ``EEGDashBaseDataset``.

        Raises
        ------
        ValueError
            If neither *query* nor *data_dir* is given, or if *data_dir*
            and *dataset* lists have different lengths.
        """
        # Avoid the shared mutable-default pitfall: build the list per call.
        if description_fields is None:
            description_fields = ['subject', 'session', 'run', 'task', 'age', 'gender', 'sex']
        self.cache_dir = cache_dir
        if query:
            datasets = self.find_datasets(query, description_fields, **kwargs)
        elif data_dir:
            if isinstance(data_dir, str):
                datasets = self.load_bids_dataset(dataset, data_dir, description_fields)
            else:
                # Raise (not assert) so validation survives ``python -O``.
                if len(data_dir) != len(dataset):
                    raise ValueError('Number of datasets and their directories must match')
                datasets = []
                for ds, ddir in zip(dataset, data_dir):
                    datasets.extend(self.load_bids_dataset(ds, ddir, description_fields))
        else:
            # Previously fell through to a NameError on ``datasets``.
            raise ValueError('Either query or data_dir must be provided')
        super().__init__(datasets)

    def find_key_in_nested_dict(self, data, target_key):
        """Depth-first search for *target_key* in nested dicts; None if absent."""
        if isinstance(data, dict):
            if target_key in data:
                return data[target_key]
            for value in data.values():
                result = self.find_key_in_nested_dict(value, target_key)
                if result is not None:
                    return result
        return None

    def find_datasets(self, query: dict, description_fields: list[str], **kwargs):
        """Query the database and wrap each record in an ``EEGDashBaseDataset``."""
        eegdashObj = EEGDash()
        datasets = []
        for record in eegdashObj.find(query):
            description = {}
            for field in description_fields:
                value = self.find_key_in_nested_dict(record, field)
                if value:
                    description[field] = value
            datasets.append(EEGDashBaseDataset(record, self.cache_dir, description=description, **kwargs))
        return datasets

    def load_bids_dataset(self, dataset, data_dir, description_fields: list[str], raw_format='eeglab', **kwargs):
        """Build ``EEGDashBaseDataset`` objects from a local BIDS directory.

        Metadata extraction runs in parallel (threads) over all files of
        the BIDS dataset.
        """
        def get_base_dataset_from_bids_file(bids_dataset, bids_file):
            # Extract the record for one file and copy the requested
            # description fields out of it.
            record = eegdashObj.load_eeg_attrs_from_bids_file(bids_dataset, bids_file)
            description = {}
            for field in description_fields:
                value = self.find_key_in_nested_dict(record, field)
                if value:
                    description[field] = value
            return EEGDashBaseDataset(record, self.cache_dir, description=description, **kwargs)

        bids_dataset = EEGBIDSDataset(
            data_dir=data_dir,
            dataset=dataset,
            raw_format=raw_format,
        )
        eegdashObj = EEGDash()
        datasets = Parallel(n_jobs=-1, prefer="threads", verbose=1)(
            delayed(get_base_dataset_from_bids_file)(bids_dataset, bids_file) for bids_file in bids_dataset.get_files()
        )
        return datasets
353
def main():
    """Smoke test: fetch one known record from the database and print it."""
    client = EEGDash()
    matches = client.find({'dataset': 'ds005511', 'subject': 'NDARUF236HM7'})
    print(matches)


if __name__ == '__main__':
    main()
@@ -1,22 +0,0 @@
1
- eegdash/__init__.py,sha256=DrliW5AazWcHJBznrmrS_YF8n8K48csOzfWWIvB6Esw,41
2
- eegdash/data_config.py,sha256=1ecgAPP4ryKJAZNX40MFLioZuG4bKTwsx-QW7L9K5nw,676
3
- eegdash/data_utils.py,sha256=NUQgMM98h6FNh6mncmWRrFEVS8s4yGx2Jg2brK_Wmv8,19256
4
- eegdash/main.py,sha256=mWlvJcVzkPtXbBf_bTTj86C3b-xI_alsmtk6Ez-gXRY,14171
5
- eegdash/features/__init__.py,sha256=6982FfzIkZ7nsAkE5d1RIDjsAEUYr8g2QWPyHpHr-Ak,604
6
- eegdash/features/datasets.py,sha256=6X6T_B8jBcuFQ-DL2TMe87ejUJuQj1cdfL3Ydt5-UZE,17177
7
- eegdash/features/decorators.py,sha256=9jdYifJhazTyklWMuUhsGgX_wW9_Ji6xY00tFPDiwFE,1266
8
- eegdash/features/extractors.py,sha256=kKhMXicAAunTCUHDvA_j275AS79E9pQTuaeMZ9aAB9o,6815
9
- eegdash/features/serialization.py,sha256=vLk5xtBqdv7UnTas_lyI6tlswkQFjV_-TaWMW2g8DLQ,2873
10
- eegdash/features/utils.py,sha256=yIqdT4DLsdf5zD8HE8bnn8hbaliUVB2v2xR0uTGPA_M,3781
11
- eegdash/features/feature_bank/__init__.py,sha256=uBHFHLmS4-bY5fL9whO1d15AiwMxB-U14sWFcArAL4o,149
12
- eegdash/features/feature_bank/complexity.py,sha256=w-0X_LPO2PlyGFfy10EwkoiKtgJ5KJk1cC7lnBDGVOM,3018
13
- eegdash/features/feature_bank/connectivity.py,sha256=egh5Iw-bnjNITuzEUnfxaqLKUB_tGxDROAgbk2MHvWg,2808
14
- eegdash/features/feature_bank/csp.py,sha256=I2u65vj_Vb-yF8iwUosuWzWbLTXm5_67_LOGrsqP6EU,3301
15
- eegdash/features/feature_bank/dimensionality.py,sha256=3-t4OLSMs1Khc-QYVz8L_jvjKjxLh6Wa_w6HeYhuX0U,3735
16
- eegdash/features/feature_bank/signal.py,sha256=eaTO_cPSytwLjadHShA4DqbZH8Q5QENXVnfOwyvPuWg,2437
17
- eegdash/features/feature_bank/spectral.py,sha256=NkKmkS9hoiJkyn4oXwRwOSwTyxtIHxe12KVIYTjeXb0,3723
18
- eegdash-0.0.9.dist-info/licenses/LICENSE,sha256=Xafu48R-h_kyaNj2tuhfgdEv9_ovciktjUEgRRwMZ6w,812
19
- eegdash-0.0.9.dist-info/METADATA,sha256=MFXDJ87JQjHXB03QELbLtHcMWL5wZ0OnRDXCexMZ-Yc,8555
20
- eegdash-0.0.9.dist-info/WHEEL,sha256=SmOxYU7pzNKBqASvQJ7DjX3XGUF92lrGhMb3R6_iiqI,91
21
- eegdash-0.0.9.dist-info/top_level.txt,sha256=zavO69HQ6MyZM0aQMR2zUS6TAFc7bnN5GEpDpOpFZzU,8
22
- eegdash-0.0.9.dist-info/RECORD,,