eegdash 0.0.8-py3-none-any.whl → 0.1.0-py3-none-any.whl

This diff shows the contents of publicly available package versions as released to a supported registry, and is provided for informational purposes only.


eegdash/main.py CHANGED
@@ -1,69 +1,81 @@
-from typing import List
-import pymongo
-from dotenv import load_dotenv
+import json
 import os
-from pathlib import Path
-import s3fs
-from joblib import Parallel, delayed
 import tempfile
+from collections import defaultdict
+from pathlib import Path
+
 import mne
 import numpy as np
+import pymongo
+import s3fs
 import xarray as xr
-from .data_utils import BIDSDataset, EEGDashBaseRaw, EEGDashBaseDataset
-from braindecode.datasets import BaseDataset, BaseConcatDataset
-from collections import defaultdict
-from pymongo import MongoClient, InsertOne, UpdateOne, DeleteOne
+from dotenv import load_dotenv
+from joblib import Parallel, delayed
+from pymongo import DeleteOne, InsertOne, MongoClient, UpdateOne
+
+from braindecode.datasets import BaseConcatDataset, BaseDataset
+
+from .data_config import config as data_config
+from .data_utils import EEGBIDSDataset, EEGDashBaseDataset, EEGDashBaseRaw
+

 class EEGDash:
-    AWS_BUCKET = 's3://openneuro.org'
-    def __init__(self,
-                 is_public=True):
+    AWS_BUCKET = "s3://openneuro.org"
+
+    def __init__(self, is_public=True):
+        # Load config file
+        # config_path = Path(__file__).parent / 'config.json'
+        # with open(config_path, 'r') as f:
+        #     self.config = json.load(f)
+
+        self.config = data_config
         if is_public:
-            DB_CONNECTION_STRING="mongodb+srv://eegdash-user:mdzoMjQcHWTVnKDq@cluster0.vz35p.mongodb.net/?retryWrites=true&w=majority&appName=Cluster0"
+            DB_CONNECTION_STRING = "mongodb+srv://eegdash-user:mdzoMjQcHWTVnKDq@cluster0.vz35p.mongodb.net/?retryWrites=true&w=majority&appName=Cluster0"
         else:
             load_dotenv()
-            DB_CONNECTION_STRING = os.getenv('DB_CONNECTION_STRING')
+            DB_CONNECTION_STRING = os.getenv("DB_CONNECTION_STRING")

         self.__client = pymongo.MongoClient(DB_CONNECTION_STRING)
-        self.__db = self.__client['eegdash']
-        self.__collection = self.__db['records']
+        self.__db = self.__client["eegdash"]
+        self.__collection = self.__db["records"]

         self.is_public = is_public
-        self.filesystem = s3fs.S3FileSystem(anon=True, client_kwargs={'region_name': 'us-east-2'})
-
+        self.filesystem = s3fs.S3FileSystem(
+            anon=True, client_kwargs={"region_name": "us-east-2"}
+        )
+
     def find(self, *args):
         results = self.__collection.find(*args)
-
+
         # convert to list using get_item on each element
         return [result for result in results]

-    def exist(self, data_name=''):
-        query = {
-            "data_name": data_name
-        }
+    def exist(self, query: dict):
+        accepted_query_fields = ["data_name", "dataset"]
+        assert all(field in accepted_query_fields for field in query.keys())
         sessions = self.find(query)
         return len(sessions) > 0

-    def _validate_input(self, record:dict):
+    def _validate_input(self, record: dict):
         input_types = {
-            'data_name': str,
-            'dataset': str,
-            'bidspath': str,
-            'subject': str,
-            'task': str,
-            'session': str,
-            'run': str,
-            'sampling_frequency': float,
-            'modality': str,
-            'nchans': int,
-            'ntimes': int,
-            'channel_types': list,
-            'channel_names': list,
+            "data_name": str,
+            "dataset": str,
+            "bidspath": str,
+            "subject": str,
+            "task": str,
+            "session": str,
+            "run": str,
+            "sampling_frequency": float,
+            "modality": str,
+            "nchans": int,
+            "ntimes": int,
+            "channel_types": list,
+            "channel_names": list,
         }
-        if 'data_name' not in record:
+        if "data_name" not in record:
             raise ValueError("Missing key: data_name")
         # check if args are in the keys and has correct type
-        for key,value in record.items():
+        for key, value in record.items():
             if key not in input_types:
                 raise ValueError(f"Invalid input: {key}")
             if not isinstance(value, input_types[key]):
@@ -72,7 +84,7 @@ class EEGDash:
         return record

     def load_eeg_data_from_s3(self, s3path):
-        with tempfile.NamedTemporaryFile(delete=False, suffix='.set') as tmp:
+        with tempfile.NamedTemporaryFile(delete=False, suffix=".set") as tmp:
             with self.filesystem.open(s3path) as s3_file:
                 tmp.write(s3_file.read())
             tmp_path = tmp.name
@@ -80,97 +92,157 @@ class EEGDash:
         os.unlink(tmp_path)
         return eeg_data

-    def load_eeg_data_from_bids_file(self, bids_file, eeg_attrs=None):
-        '''
+    def load_eeg_data_from_bids_file(self, bids_file, eeg_attrs=None):
+        """
         bids_file must be a file of the bids_dataset
-        '''
+        """
         EEG = mne.io.read_raw_eeglab(bids_file)
         eeg_data = EEG.get_data()
-
-        fs = EEG.info['sfreq']
+
+        fs = EEG.info["sfreq"]
         max_time = eeg_data.shape[1] / fs
-        time_steps = np.linspace(0, max_time, eeg_data.shape[1]).squeeze() # in seconds
+        time_steps = np.linspace(0, max_time, eeg_data.shape[1]).squeeze()  # in seconds

         channel_names = EEG.ch_names

         eeg_xarray = xr.DataArray(
             data=eeg_data,
-            dims=['channel','time'],
-            coords={
-                'time': time_steps,
-                'channel': channel_names
-            },
+            dims=["channel", "time"],
+            coords={"time": time_steps, "channel": channel_names},
             # attrs=attrs
         )
         return eeg_xarray

-    def load_eeg_attrs_from_bids_file(self, bids_dataset: BIDSDataset, bids_file):
-        '''
+    def get_raw_extensions(self, bids_file, bids_dataset: EEGBIDSDataset):
+        bids_file = Path(bids_file)
+        extensions = {
+            ".set": [".set", ".fdt"],  # eeglab
+            ".edf": [".edf"],  # european
+            ".vhdr": [".eeg", ".vhdr", ".vmrk", ".dat", ".raw"],  # brainvision
+            ".bdf": [".bdf"],  # biosemi
+        }
+        return [
+            str(bids_dataset.get_relative_bidspath(bids_file.with_suffix(suffix)))
+            for suffix in extensions[bids_file.suffix]
+            if bids_file.with_suffix(suffix).exists()
+        ]
+
+    def load_eeg_attrs_from_bids_file(self, bids_dataset: EEGBIDSDataset, bids_file):
+        """
         bids_file must be a file of the bids_dataset
-        '''
+        """
         if bids_file not in bids_dataset.files:
-            raise ValueError(f'{bids_file} not in {bids_dataset.dataset}')
+            raise ValueError(f"{bids_file} not in {bids_dataset.dataset}")
+
+        # Initialize attrs with None values for all expected fields
+        attrs = {field: None for field in self.config["attributes"].keys()}
+
         f = os.path.basename(bids_file)
         dsnumber = bids_dataset.dataset
         # extract openneuro path by finding the first occurrence of the dataset name in the filename and remove the path before that
         openneuro_path = dsnumber + bids_file.split(dsnumber)[1]

-        bids_dependencies_files = ['dataset_description.json', 'participants.tsv', 'events.tsv', 'events.json', 'eeg.json', 'electrodes.tsv', 'channels.tsv', 'coordsystem.json']
+        # Update with actual values where available
+        try:
+            participants_tsv = bids_dataset.subject_participant_tsv(bids_file)
+        except Exception as e:
+            print(f"Error getting participants_tsv: {str(e)}")
+            participants_tsv = None
+
+        try:
+            eeg_json = bids_dataset.eeg_json(bids_file)
+        except Exception as e:
+            print(f"Error getting eeg_json: {str(e)}")
+            eeg_json = None
+
+        bids_dependencies_files = self.config["bids_dependencies_files"]
         bidsdependencies = []
         for extension in bids_dependencies_files:
-            dep_path = bids_dataset.get_bids_metadata_files(bids_file, extension)
-            dep_path = [str(bids_dataset.get_relative_bidspath(dep)) for dep in dep_path]
-
-            bidsdependencies.extend(dep_path)
-
-        participants_tsv = bids_dataset.subject_participant_tsv(bids_file)
-        eeg_json = bids_dataset.eeg_json(bids_file)
-        attrs = {
-            'data_name': f'{bids_dataset.dataset}_{f}',
-            'dataset': bids_dataset.dataset,
-            'bidspath': openneuro_path,
-            'subject': bids_dataset.subject(bids_file),
-            'task': bids_dataset.task(bids_file),
-            'session': bids_dataset.session(bids_file),
-            'run': bids_dataset.run(bids_file),
-            'modality': 'EEG',
-            'nchans': bids_dataset.num_channels(bids_file),
-            'ntimes': bids_dataset.num_times(bids_file),
-            'participant_tsv': participants_tsv,
-            'eeg_json': eeg_json,
-            'bidsdependencies': bidsdependencies,
+            try:
+                dep_path = bids_dataset.get_bids_metadata_files(bids_file, extension)
+                dep_path = [
+                    str(bids_dataset.get_relative_bidspath(dep)) for dep in dep_path
+                ]
+                bidsdependencies.extend(dep_path)
+            except Exception as e:
+                pass
+
+        bidsdependencies.extend(self.get_raw_extensions(bids_file, bids_dataset))
+
+        # Define field extraction functions with error handling
+        field_extractors = {
+            "data_name": lambda: f"{bids_dataset.dataset}_{f}",
+            "dataset": lambda: bids_dataset.dataset,
+            "bidspath": lambda: openneuro_path,
+            "subject": lambda: bids_dataset.get_bids_file_attribute(
+                "subject", bids_file
+            ),
+            "task": lambda: bids_dataset.get_bids_file_attribute("task", bids_file),
+            "session": lambda: bids_dataset.get_bids_file_attribute(
+                "session", bids_file
+            ),
+            "run": lambda: bids_dataset.get_bids_file_attribute("run", bids_file),
+            "modality": lambda: bids_dataset.get_bids_file_attribute(
+                "modality", bids_file
+            ),
+            "sampling_frequency": lambda: bids_dataset.get_bids_file_attribute(
+                "sfreq", bids_file
+            ),
+            "nchans": lambda: bids_dataset.get_bids_file_attribute("nchans", bids_file),
+            "ntimes": lambda: bids_dataset.get_bids_file_attribute("ntimes", bids_file),
+            "participant_tsv": lambda: participants_tsv,
+            "eeg_json": lambda: eeg_json,
+            "bidsdependencies": lambda: bidsdependencies,
         }

+        # Dynamically populate attrs with error handling
+        for field, extractor in field_extractors.items():
+            try:
+                attrs[field] = extractor()
+            except Exception as e:
+                print(f"Error extracting {field}: {str(e)}")
+                attrs[field] = None
+
         return attrs

-    def add_bids_dataset(self, dataset, data_dir, raw_format='eeglab', overwrite=True):
-        '''
+    def add_bids_dataset(self, dataset, data_dir, overwrite=True):
+        """
         Create new records for the dataset in the MongoDB database if not found
-        '''
+        """
         if self.is_public:
-            raise ValueError('This operation is not allowed for public users')
+            raise ValueError("This operation is not allowed for public users")

-        bids_dataset = BIDSDataset(
-            data_dir=data_dir,
-            dataset=dataset,
-            raw_format=raw_format,
-        )
+        if not overwrite and self.exist({"dataset": dataset}):
+            print(f"Dataset {dataset} already exists in the database")
+            return
+        try:
+            bids_dataset = EEGBIDSDataset(
+                data_dir=data_dir,
+                dataset=dataset,
+            )
+        except Exception as e:
+            print(f"Error creating bids dataset {dataset}: {str(e)}")
+            raise e
         requests = []
         for bids_file in bids_dataset.get_files():
             try:
                 data_id = f"{dataset}_{os.path.basename(bids_file)}"

-                if self.exist(data_name=data_id):
+                if self.exist({"data_name": data_id}):
                     if overwrite:
-                        eeg_attrs = self.load_eeg_attrs_from_bids_file(bids_dataset, bids_file)
-                        requests.append(UpdateOne(self.update_request(eeg_attrs)))
+                        eeg_attrs = self.load_eeg_attrs_from_bids_file(
+                            bids_dataset, bids_file
+                        )
+                        requests.append(self.update_request(eeg_attrs))
                 else:
-                    eeg_attrs = self.load_eeg_attrs_from_bids_file(bids_dataset, bids_file)
+                    eeg_attrs = self.load_eeg_attrs_from_bids_file(
+                        bids_dataset, bids_file
+                    )
                     requests.append(self.add_request(eeg_attrs))
             except:
-                print('error adding record', bids_file)
+                print("error adding record", bids_file)

-        print('Number of database requests', len(requests))
+        print("Number of database requests", len(requests))

         if requests:
             result = self.__collection.bulk_write(requests, ordered=False)
@@ -180,25 +252,28 @@ class EEGDash:
             print(f"Upserted: {result.upserted_count}")
             print(f"Errors: {result.bulk_api_result.get('writeErrors', [])}")

-    def get(self, query:dict):
-        '''
+    def get(self, query: dict):
+        """
         query: {
             'dataset': 'dsxxxx',

-        }'''
+        }"""
         sessions = self.find(query)
         results = []
         if sessions:
-            print(f'Found {len(sessions)} records')
-            results = Parallel(n_jobs=-1 if len(sessions) > 1 else 1, prefer="threads", verbose=1)(
-                delayed(self.load_eeg_data_from_s3)(self.get_s3path(session)) for session in sessions
+            print(f"Found {len(sessions)} records")
+            results = Parallel(
+                n_jobs=-1 if len(sessions) > 1 else 1, prefer="threads", verbose=1
+            )(
+                delayed(self.load_eeg_data_from_s3)(self.get_s3path(session))
+                for session in sessions
             )
         return results

-    def add_request(self, record:dict):
+    def add_request(self, record: dict):
         return InsertOne(record)

-    def add(self, record:dict):
+    def add(self, record: dict):
         try:
             # input_record = self._validate_input(record)
             self.__collection.insert_one(record)
@@ -206,48 +281,72 @@ class EEGDash:
         except ValueError as e:
             print(f"Failed to validate record: {record['data_name']}")
             print(e)
-        except:
+        except:
             print(f"Error adding record: {record['data_name']}")

-    def update_request(self, record:dict):
-        return UpdateOne({'data_name': record['data_name']}, {'$set': record})
+    def update_request(self, record: dict):
+        return UpdateOne({"data_name": record["data_name"]}, {"$set": record})

-    def update(self, record:dict):
+    def update(self, record: dict):
         try:
-            self.__collection.update_one({'data_name': record['data_name']}, {'$set': record})
-        except: # silent failure
-            print(f'Error updating record {record["data_name"]}')
+            self.__collection.update_one(
+                {"data_name": record["data_name"]}, {"$set": record}
+            )
+        except:  # silent failure
+            print(f"Error updating record {record['data_name']}")

     def remove_field(self, record, field):
-        self.__collection.update_one({'data_name': record['data_name']}, {'$unset': {field: 1}})
-
+        self.__collection.update_one(
+            {"data_name": record["data_name"]}, {"$unset": {field: 1}}
+        )
+
     def remove_field_from_db(self, field):
-        self.__collection.update_many({}, {'$unset': {field: 1}})
-
+        self.__collection.update_many({}, {"$unset": {field: 1}})
+
+    @property
+    def collection(self):
+        return self.__collection
+

 class EEGDashDataset(BaseConcatDataset):
-    CACHE_DIR = '.eegdash_cache'
+    # CACHE_DIR = '.eegdash_cache'
     def __init__(
         self,
-        query:dict=None,
-        data_dir:str | list =None,
-        dataset:str | list =None,
-        description_fields: list[str]=['subject', 'session', 'run', 'task', 'age', 'gender', 'sex'],
-        **kwargs
+        query: dict = None,
+        data_dir: str | list = None,
+        dataset: str | list = None,
+        description_fields: list[str] = [
+            "subject",
+            "session",
+            "run",
+            "task",
+            "age",
+            "gender",
+            "sex",
+        ],
+        cache_dir: str = ".eegdash_cache",
+        **kwargs,
     ):
+        self.cache_dir = cache_dir
         if query:
             datasets = self.find_datasets(query, description_fields, **kwargs)
         elif data_dir:
             if type(data_dir) == str:
                 datasets = self.load_bids_dataset(dataset, data_dir, description_fields)
             else:
-                assert len(data_dir) == len(dataset), 'Number of datasets and their directories must match'
+                assert len(data_dir) == len(dataset), (
+                    "Number of datasets and their directories must match"
+                )
                 datasets = []
                 for i in range(len(data_dir)):
-                    datasets.extend(self.load_bids_dataset(dataset[i], data_dir[i], description_fields))
+                    datasets.extend(
+                        self.load_bids_dataset(
+                            dataset[i], data_dir[i], description_fields
+                        )
+                    )
         # convert to list using get_item on each element
         super().__init__(datasets)
-
+
     def find_key_in_nested_dict(self, data, target_key):
         if isinstance(data, dict):
             if target_key in data:
@@ -258,54 +357,60 @@ class EEGDashDataset(BaseConcatDataset):
             return result
         return None

-    def find_datasets(self, query:dict, description_fields:list[str], **kwargs):
+    def find_datasets(self, query: dict, description_fields: list[str], **kwargs):
         eegdashObj = EEGDash()
         datasets = []
         for record in eegdashObj.find(query):
             description = {}
             for field in description_fields:
                 value = self.find_key_in_nested_dict(record, field)
-                if value:
+                if value is not None:
                     description[field] = value
-            datasets.append(EEGDashBaseDataset(record, self.CACHE_DIR, description=description, **kwargs))
+            datasets.append(
+                EEGDashBaseDataset(
+                    record, self.cache_dir, description=description, **kwargs
+                )
+            )
         return datasets

-    def load_bids_dataset(self, dataset, data_dir, description_fields: list[str],raw_format='eeglab', **kwargs):
-        '''
-        '''
+    def load_bids_dataset(
+        self,
+        dataset,
+        data_dir,
+        description_fields: list[str],
+        raw_format="eeglab",
+        **kwargs,
+    ):
+        """ """
+
         def get_base_dataset_from_bids_file(bids_dataset, bids_file):
             record = eegdashObj.load_eeg_attrs_from_bids_file(bids_dataset, bids_file)
             description = {}
             for field in description_fields:
                 value = self.find_key_in_nested_dict(record, field)
-                if value:
+                if value is not None:
                     description[field] = value
-            return EEGDashBaseDataset(record, self.CACHE_DIR, description=description, **kwargs)
+            return EEGDashBaseDataset(
+                record, self.cache_dir, description=description, **kwargs
+            )

-        bids_dataset = BIDSDataset(
+        bids_dataset = EEGBIDSDataset(
             data_dir=data_dir,
             dataset=dataset,
-            raw_format=raw_format,
         )
         eegdashObj = EEGDash()
         datasets = Parallel(n_jobs=-1, prefer="threads", verbose=1)(
-            delayed(get_base_dataset_from_bids_file)(bids_dataset, bids_file) for bids_file in bids_dataset.get_files()
-        )
-        # datasets = []
-        # for bids_file in bids_dataset.get_files():
-        #     record = eegdashObj.load_eeg_attrs_from_bids_file(bids_dataset, bids_file)
-        #     description = {}
-        #     for field in description_fields:
-        #         value = self.find_key_in_nested_dict(record, field)
-        #         if value:
-        #             description[field] = value
-        #     datasets.append(EEGDashBaseDataset(record, self.CACHE_DIR, description=description, **kwargs))
+            delayed(get_base_dataset_from_bids_file)(bids_dataset, bids_file)
+            for bids_file in bids_dataset.get_files()
+        )
         return datasets

+
 def main():
     eegdash = EEGDash()
-    record = eegdash.find({'dataset': 'ds005511', 'subject': 'NDARUF236HM7'})
+    record = eegdash.find({"dataset": "ds005511", "subject": "NDARUF236HM7"})
     print(record)

-if __name__ == '__main__':
-    main()
+
+if __name__ == "__main__":
+    main()
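
Beyond the cosmetic reformatting (Black-style quoting and line wrapping), the non-trivial changes in this file are to the public entry points: exist() now takes a query dict restricted to the "data_name" and "dataset" fields, add_bids_dataset() drops its raw_format parameter, metadata extraction is driven by the new eegdash.data_config module with per-field error handling, and EEGDashDataset replaces the CACHE_DIR class attribute with a cache_dir constructor argument. Below is a minimal sketch of the 0.1.0 call patterns, assuming the package's public MongoDB cluster and the OpenNeuro S3 bucket are reachable; the dataset and subject identifiers are reused from the module's own main() and are illustrative only.

# Sketch of the 0.1.0 API surface shown in this diff (not an official
# example). Assumes eegdash 0.1.0 is installed with network access.
from eegdash.main import EEGDash, EEGDashDataset

eegdash = EEGDash()  # is_public=True by default

# exist() now takes a query dict limited to "data_name"/"dataset";
# 0.0.8 took a data_name="..." keyword instead.
if eegdash.exist({"dataset": "ds005511"}):
    # find() still forwards its arguments to pymongo's collection.find()
    # and returns the matching records as a list.
    records = eegdash.find({"dataset": "ds005511", "subject": "NDARUF236HM7"})
    print(f"Found {len(records)} records")

# The cache location is now a constructor argument (cache_dir) instead
# of the old CACHE_DIR class attribute.
dataset = EEGDashDataset(
    query={"dataset": "ds005511", "subject": "NDARUF236HM7"},
    cache_dir=".eegdash_cache",
)

Note that exist() validates its query fields with a bare assert, so an unknown field raises AssertionError rather than a descriptive error, and the check disappears entirely under python -O; worth keeping in mind when porting 0.0.8 call sites.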