eegdash 0.0.8__py3-none-any.whl → 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of eegdash might be problematic.

eegdash/__init__.py CHANGED
@@ -1 +1,4 @@
- from .main import EEGDash, EEGDashDataset
+ from .main import EEGDash, EEGDashDataset
+
+ __all__ = ["EEGDash", "EEGDashDataset"]
+ __version__ = "0.1.0"
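With this change the package exposes an explicit public API (`__all__`) and a version string. A minimal usage sketch, relying only on what the diff above adds:

```python
# Minimal sketch: the two public entry points re-exported by eegdash/__init__.py,
# plus the __version__ attribute introduced in 0.1.0.
import eegdash

print(eegdash.__version__)  # "0.1.0"

from eegdash import EEGDash, EEGDashDataset  # the names listed in __all__
```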
eegdash/data_config.py ADDED
@@ -0,0 +1,28 @@
+ config = {
+     "required_fields": ["data_name"],
+     "attributes": {
+         "data_name": "str",
+         "dataset": "str",
+         "bidspath": "str",
+         "subject": "str",
+         "task": "str",
+         "session": "str",
+         "run": "str",
+         "sampling_frequency": "float",
+         "modality": "str",
+         "nchans": "int",
+         "ntimes": "int",
+     },
+     "description_fields": ["subject", "session", "run", "task", "age", "gender", "sex"],
+     "bids_dependencies_files": [
+         "dataset_description.json",
+         "participants.tsv",
+         "events.tsv",
+         "events.json",
+         "eeg.json",
+         "electrodes.tsv",
+         "channels.tsv",
+         "coordsystem.json",
+     ],
+     "accepted_query_fields": ["data_name", "dataset"],
+ }
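The new `data_config.py` centralizes the record schema: required fields, attribute types, the BIDS sidecar files a recording depends on, and which fields may be queried. A hypothetical sketch of how `accepted_query_fields` could be used to screen a query dict; the `validate_query` helper below is illustrative, not part of the package:

```python
from eegdash.data_config import config

def validate_query(query: dict) -> dict:
    # Illustrative helper (not in eegdash): reject any field that
    # data_config.config does not declare as queryable.
    accepted = set(config["accepted_query_fields"])
    unknown = set(query) - accepted
    if unknown:
        raise ValueError(f"unsupported query fields: {sorted(unknown)}")
    return query

validate_query({"dataset": "ds002718"})  # passes ("ds002718" is just an example ID)
# validate_query({"subject": "012"})     # would raise ValueError
```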
eegdash/data_utils.py CHANGED
@@ -1,23 +1,26 @@
+ import json
  import os
- import sys
- from joblib import Parallel, delayed
+ import re
+ import sys
+ import tempfile
+ from pathlib import Path
+
  import mne
+ import mne_bids
  import numpy as np
  import pandas as pd
- from pathlib import Path
- import re
- import json
- from mne.io import BaseRaw
- from mne._fiff.utils import _find_channels, _read_segments_file
  import s3fs
- import tempfile
- from mne._fiff.utils import _read_segments_file
- from braindecode.datasets import BaseDataset
- import mne_bids
+ from bids import BIDSLayout
+ from joblib import Parallel, delayed
+ from mne._fiff.utils import _find_channels, _read_segments_file
+ from mne.io import BaseRaw
  from mne_bids import (
      BIDSPath,
  )

+ from braindecode.datasets import BaseDataset
+
+
  class EEGDashBaseDataset(BaseDataset):
      """Returns samples from an mne.io.Raw object along with a target.

@@ -37,16 +40,23 @@ class EEGDashBaseDataset(BaseDataset):
      transform : callable | None
          On-the-fly transform applied to the example before it is returned.
      """
-     AWS_BUCKET = 's3://openneuro.org'
+
+     AWS_BUCKET = "s3://openneuro.org"
+
      def __init__(self, record, cache_dir, **kwargs):
          super().__init__(None, **kwargs)
          self.record = record
          self.cache_dir = Path(cache_dir)
          bids_kwargs = self.get_raw_bids_args()
-         self.bidspath = BIDSPath(root=self.cache_dir / record['dataset'], datatype='eeg', suffix='eeg', **bids_kwargs)
-         self.s3file = self.get_s3path(record['bidspath'])
-         self.filecache = self.cache_dir / record['bidspath']
-         self.bids_dependencies = record['bidsdependencies']
+         self.bidspath = BIDSPath(
+             root=self.cache_dir / record["dataset"],
+             datatype="eeg",
+             suffix="eeg",
+             **bids_kwargs,
+         )
+         self.s3file = self.get_s3path(record["bidspath"])
+         self.filecache = self.cache_dir / record["bidspath"]
+         self.bids_dependencies = record["bidsdependencies"]
          self._raw = None
          # if os.path.exists(self.filecache):
          #     self.raw = mne_bids.read_raw_bids(self.bidspath, verbose=False)
@@ -56,25 +66,29 @@ class EEGDashBaseDataset(BaseDataset):

      def _download_s3(self):
          self.filecache.parent.mkdir(parents=True, exist_ok=True)
-         filesystem = s3fs.S3FileSystem(anon=True, client_kwargs={'region_name': 'us-east-2'})
+         filesystem = s3fs.S3FileSystem(
+             anon=True, client_kwargs={"region_name": "us-east-2"}
+         )
          filesystem.download(self.s3file, self.filecache)
          self.filenames = [self.filecache]

      def _download_dependencies(self):
-         filesystem = s3fs.S3FileSystem(anon=True, client_kwargs={'region_name': 'us-east-2'})
+         filesystem = s3fs.S3FileSystem(
+             anon=True, client_kwargs={"region_name": "us-east-2"}
+         )
          for dep in self.bids_dependencies:
              s3path = self.get_s3path(dep)
              filepath = self.cache_dir / dep
              if not filepath.exists():
                  filepath.parent.mkdir(parents=True, exist_ok=True)
-                 filesystem.download(s3path, filepath)
+                 filesystem.download(s3path, filepath)

      def get_raw_bids_args(self):
-         desired_fields = ['subject', 'session', 'task', 'run']
+         desired_fields = ["subject", "session", "task", "run"]
          return {k: self.record[k] for k in desired_fields if self.record[k]}

      def check_and_get_raw(self):
-         if not os.path.exists(self.filecache): # not preload
+         if not os.path.exists(self.filecache):  # not preload
              if self.bids_dependencies:
                  self._download_dependencies()
              self._download_s3()
@@ -93,10 +107,10 @@ class EEGDashBaseDataset(BaseDataset):
          if self.transform is not None:
              X = self.transform(X)
          return X, y
-
+
      def __len__(self):
          if self._raw is None:
-             return self.record['rawdatainfo']['ntimes']
+             return int(self.record["ntimes"] * self.record["sampling_frequency"])
          else:
              return len(self._raw)

@@ -110,6 +124,7 @@ class EEGDashBaseDataset(BaseDataset):
      def raw(self, raw):
          self._raw = raw

+
  class EEGDashBaseRaw(BaseRaw):
      r"""MNE Raw object from EEG-Dash connection with Openneuro S3 file.

@@ -137,7 +152,8 @@ class EEGDashBaseRaw(BaseRaw):
      .. versionadded:: 0.11.0
      """

-     AWS_BUCKET = 's3://openneuro.org'
+     AWS_BUCKET = "s3://openneuro.org"
+
      def __init__(
          self,
          input_fname,
@@ -145,24 +161,24 @@ class EEGDashBaseRaw(BaseRaw):
          eog=(),
          preload=False,
          *,
-         cache_dir='./.eegdash_cache',
-         bids_dependencies:list = [],
+         cache_dir="./.eegdash_cache",
+         bids_dependencies: list = [],
          uint16_codec=None,
          montage_units="auto",
          verbose=None,
      ):
-         '''
+         """
          Get to work with S3 endpoint first, no caching
-         '''
+         """
          # Create a simple RawArray
-         sfreq = metadata['sfreq'] # Sampling frequency
-         n_times = metadata['n_times']
-         ch_names = metadata['ch_names']
+         sfreq = metadata["sfreq"]  # Sampling frequency
+         n_times = metadata["n_times"]
+         ch_names = metadata["ch_names"]
          ch_types = []
-         for ch in metadata['ch_types']:
+         for ch in metadata["ch_types"]:
              chtype = ch.lower()
-             if chtype == 'heog' or chtype == 'veog':
-                 chtype = 'eog'
+             if chtype == "heog" or chtype == "veog":
+                 chtype = "eog"
              ch_types.append(chtype)
          info = mne.create_info(ch_names=ch_names, sfreq=sfreq, ch_types=ch_types)
          self.s3file = self.get_s3path(input_fname)
@@ -177,7 +193,7 @@ class EEGDashBaseRaw(BaseRaw):
          super().__init__(
              info,
              preload,
-             last_samps=[n_times-1],
+             last_samps=[n_times - 1],
              orig_format="single",
              verbose=verbose,
          )
@@ -187,12 +203,16 @@ class EEGDashBaseRaw(BaseRaw):

      def _download_s3(self):
          self.filecache.parent.mkdir(parents=True, exist_ok=True)
-         filesystem = s3fs.S3FileSystem(anon=True, client_kwargs={'region_name': 'us-east-2'})
+         filesystem = s3fs.S3FileSystem(
+             anon=True, client_kwargs={"region_name": "us-east-2"}
+         )
          filesystem.download(self.s3file, self.filecache)
          self.filenames = [self.filecache]

      def _download_dependencies(self):
-         filesystem = s3fs.S3FileSystem(anon=True, client_kwargs={'region_name': 'us-east-2'})
+         filesystem = s3fs.S3FileSystem(
+             anon=True, client_kwargs={"region_name": "us-east-2"}
+         )
          for dep in self.bids_dependencies:
              s3path = self.get_s3path(dep)
              filepath = self.cache_dir / dep
@@ -203,71 +223,92 @@ class EEGDashBaseRaw(BaseRaw):
      def _read_segment(
          self, start=0, stop=None, sel=None, data_buffer=None, *, verbose=None
      ):
-         if not os.path.exists(self.filecache): # not preload
+         if not os.path.exists(self.filecache):  # not preload
              if self.bids_dependencies:
                  self._download_dependencies()
              self._download_s3()
-         else: # not preload and file is not cached
+         else:  # not preload and file is not cached
              self.filenames = [self.filecache]
          return super()._read_segment(start, stop, sel, data_buffer, verbose=verbose)
-
+
      def _read_segment_file(self, data, idx, fi, start, stop, cals, mult):
          """Read a chunk of data from the file."""
          _read_segments_file(self, data, idx, fi, start, stop, cals, mult, dtype="<f4")


- class BIDSDataset():
-     ALLOWED_FILE_FORMAT = ['eeglab', 'brainvision', 'biosemi', 'european']
-     RAW_EXTENSION = {
-         'eeglab': '.set',
-         'brainvision': '.vhdr',
-         'biosemi': '.bdf',
-         'european': '.edf'
+ class EEGBIDSDataset:
+     ALLOWED_FILE_FORMAT = ["eeglab", "brainvision", "biosemi", "european"]
+     RAW_EXTENSIONS = {
+         ".set": [".set", ".fdt"],  # eeglab
+         ".edf": [".edf"],  # european
+         ".vhdr": [".eeg", ".vhdr", ".vmrk", ".dat", ".raw"],  # brainvision
+         ".bdf": [".bdf"],  # biosemi
      }
-     METADATA_FILE_EXTENSIONS = ['eeg.json', 'channels.tsv', 'electrodes.tsv', 'events.tsv', 'events.json']
-     def __init__(self,
-                  data_dir=None, # location of bids dataset
-                  dataset='', # dataset name
-                  raw_format='eeglab', # format of raw data
-                  ):
+     METADATA_FILE_EXTENSIONS = [
+         "eeg.json",
+         "channels.tsv",
+         "electrodes.tsv",
+         "events.tsv",
+         "events.json",
+     ]
+
+     def __init__(
+         self,
+         data_dir=None,  # location of bids dataset
+         dataset="",  # dataset name
+     ):
          if data_dir is None or not os.path.exists(data_dir):
-             raise ValueError('data_dir must be specified and must exist')
+             raise ValueError("data_dir must be specified and must exist")
          self.bidsdir = Path(data_dir)
          self.dataset = dataset
          assert str(self.bidsdir).endswith(self.dataset)
+         self.layout = BIDSLayout(data_dir)

-         if raw_format.lower() not in self.ALLOWED_FILE_FORMAT:
-             raise ValueError('raw_format must be one of {}'.format(self.ALLOWED_FILE_FORMAT))
-         self.raw_format = raw_format.lower()
-
-         # get all .set files in the bids directory
-         temp_dir = (Path().resolve() / 'data')
-         if not os.path.exists(temp_dir):
-             os.mkdir(temp_dir)
-         if not os.path.exists(temp_dir / f'{dataset}_files.npy'):
-             self.files = self.get_files_with_extension_parallel(self.bidsdir, extension=self.RAW_EXTENSION[self.raw_format])
-             np.save(temp_dir / f'{dataset}_files.npy', self.files)
-         else:
-             self.files = np.load(temp_dir / f'{dataset}_files.npy', allow_pickle=True)
+         # get all recording files in the bids directory
+         self.files = self.get_recordings(self.layout)
+         assert len(self.files) > 0, ValueError(
+             "Unable to construct EEG dataset. No EEG recordings found."
+         )
+         assert self.check_eeg_dataset(), ValueError("Dataset is not an EEG dataset.")
+         # temp_dir = (Path().resolve() / 'data')
+         # if not os.path.exists(temp_dir):
+         #     os.mkdir(temp_dir)
+         # if not os.path.exists(temp_dir / f'{dataset}_files.npy'):
+         #     self.files = self.get_files_with_extension_parallel(self.bidsdir, extension=self.RAW_EXTENSION[self.raw_format])
+         #     np.save(temp_dir / f'{dataset}_files.npy', self.files)
+         # else:
+         #     self.files = np.load(temp_dir / f'{dataset}_files.npy', allow_pickle=True)
+
+     def check_eeg_dataset(self):
+         return self.get_bids_file_attribute("modality", self.files[0]).lower() == "eeg"
+
+     def get_recordings(self, layout: BIDSLayout):
+         files = []
+         for ext, exts in self.RAW_EXTENSIONS.items():
+             files = layout.get(extension=ext, return_type="filename")
+             if files:
+                 break
+         return files

      def get_relative_bidspath(self, filename):
-         bids_parent_dir = self.bidsdir.parent
+         bids_parent_dir = self.bidsdir.parent.absolute()
          return str(Path(filename).relative_to(bids_parent_dir))

      def get_property_from_filename(self, property, filename):
          import platform
+
          if platform.system() == "Windows":
-             lookup = re.search(rf'{property}-(.*?)[_\\]', filename)
+             lookup = re.search(rf"{property}-(.*?)[_\\]", filename)
          else:
-             lookup = re.search(rf'{property}-(.*?)[_\/]', filename)
-         return lookup.group(1) if lookup else ''
+             lookup = re.search(rf"{property}-(.*?)[_\/]", filename)
+         return lookup.group(1) if lookup else ""

      def merge_json_inheritance(self, json_files):
-         '''
+         """
          Merge list of json files found by get_bids_file_inheritance,
          expecting the order (from left to right) is from lowest level to highest level,
          and return a merged dictionary
-         '''
+         """
          json_files.reverse()
          json_dict = {}
          for f in json_files:
@@ -275,74 +316,73 @@ class BIDSDataset():
          return json_dict

      def get_bids_file_inheritance(self, path, basename, extension):
-         '''
-         Get all files with given extension that applies to the basename file 
+         """
+         Get all files with given extension that applies to the basename file
          following the BIDS inheritance principle in the order of lowest level first
          @param
              basename: bids file basename without _eeg.set extension for example
              extension: e.g. channels.tsv
-         '''
-         top_level_files = ['README', 'dataset_description.json', 'participants.tsv']
+         """
+         top_level_files = ["README", "dataset_description.json", "participants.tsv"]
          bids_files = []

          # check if path is str object
          if isinstance(path, str):
              path = Path(path)
          if not path.exists:
-             raise ValueError('path {path} does not exist')
+             raise ValueError("path {path} does not exist")

          # check if file is in current path
          for file in os.listdir(path):
              # target_file = path / f"{cur_file_basename}_{extension}"
-             if os.path.isfile(path/file):
+             if os.path.isfile(path / file):
                  # check if file has extension extension
                  # check if file basename has extension
                  if file.endswith(extension):
                      filepath = path / file
                      bids_files.append(filepath)

-         # cur_file_basename = file[:file.rfind('_')] # TODO: change to just search for any file with extension
-         # if file.endswith(extension) and cur_file_basename in basename:
-         #     filepath = path / file
-         #     bids_files.append(filepath)
-
          # check if file is in top level directory
          if any(file in os.listdir(path) for file in top_level_files):
              return bids_files
          else:
              # call get_bids_file_inheritance recursively with parent directory
-             bids_files.extend(self.get_bids_file_inheritance(path.parent, basename, extension))
+             bids_files.extend(
+                 self.get_bids_file_inheritance(path.parent, basename, extension)
+             )
              return bids_files

      def get_bids_metadata_files(self, filepath, metadata_file_extension):
          """
          (Wrapper for self.get_bids_file_inheritance)
          Get all BIDS metadata files that are associated with the given filepath, following the BIDS inheritance principle.
-
+
          Args:
              filepath (str or Path): The filepath to get the associated metadata files for.
              metadata_files_extensions (list): A list of file extensions to search for metadata files.
-
+
          Returns:
              list: A list of filepaths for all the associated metadata files
          """
          if isinstance(filepath, str):
              filepath = Path(filepath)
          if not filepath.exists:
-             raise ValueError('filepath {filepath} does not exist')
+             raise ValueError("filepath {filepath} does not exist")
          path, filename = os.path.split(filepath)
-         basename = filename[:filename.rfind('_')]
+         basename = filename[: filename.rfind("_")]
          # metadata files
-         meta_files = self.get_bids_file_inheritance(path, basename, metadata_file_extension)
+         meta_files = self.get_bids_file_inheritance(
+             path, basename, metadata_file_extension
+         )
          return meta_files
-
+
      def scan_directory(self, directory, extension):
          result_files = []
-         directory_to_ignore = ['.git']
+         directory_to_ignore = [".git", ".datalad", "derivatives", "code"]
          with os.scandir(directory) as entries:
              for entry in entries:
                  if entry.is_file() and entry.name.endswith(extension):
-                     print('Adding ', entry.path)
+                     print("Adding ", entry.path)
                      result_files.append(entry.path)
                  elif entry.is_dir():
                      # check that entry path doesn't contain any name in ignore list
@@ -350,7 +390,9 @@ class BIDSDataset():
                          result_files.append(entry.path) # Add directory to scan later
          return result_files

-     def get_files_with_extension_parallel(self, directory, extension='.set', max_workers=-1):
+     def get_files_with_extension_parallel(
+         self, directory, extension=".set", max_workers=-1
+     ):
          result_files = []
          dirs_to_scan = [directory]

@@ -361,7 +403,7 @@ class BIDSDataset():
            results = Parallel(n_jobs=max_workers, prefer="threads", verbose=1)(
                delayed(self.scan_directory)(d, extension) for d in dirs_to_scan
            )
-
+
            # Reset the directories to scan and process the results
            dirs_to_scan = []
            for res in results:
@@ -376,8 +418,8 @@ class BIDSDataset():

      def load_and_preprocess_raw(self, raw_file, preprocess=False):
          print(f"Loading {raw_file}")
-         EEG = mne.io.read_raw_eeglab(raw_file, preload=True, verbose='error')
-
+         EEG = mne.io.read_raw_eeglab(raw_file, preload=True, verbose="error")
+
          if preprocess:
              # highpass filter
              EEG = EEG.filter(l_freq=0.25, h_freq=25, verbose=False)
@@ -385,7 +427,7 @@ class BIDSDataset():
              EEG = EEG.notch_filter(freqs=(60), verbose=False)
              # bring to common sampling rate
              sfreq = 128
-             if EEG.info['sfreq'] != sfreq:
+             if EEG.info["sfreq"] != sfreq:
                  EEG = EEG.resample(sfreq)
          # # normalize data to zero mean and unit variance
          # scalar = preprocessing.StandardScaler()
@@ -394,12 +436,12 @@ class BIDSDataset():
          mat_data = EEG.get_data()

          if len(mat_data.shape) > 2:
-             raise ValueError('Expect raw data to be CxT dimension')
+             raise ValueError("Expect raw data to be CxT dimension")
          return mat_data
-
+
      def get_files(self):
          return self.files
-
+
      def resolve_bids_json(self, json_files: list):
          """
          Resolve the BIDS JSON files and return a dictionary of the resolved values.
@@ -410,8 +452,8 @@ class BIDSDataset():
              dict: A dictionary of the resolved values.
          """
          if len(json_files) == 0:
-             raise ValueError('No JSON files provided')
-         json_files.reverse() # TODO undeterministic
+             raise ValueError("No JSON files provided")
+         json_files.reverse()  # TODO undeterministic

          json_dict = {}
          for json_file in json_files:
@@ -419,63 +461,66 @@ class BIDSDataset():
              json_dict.update(json.load(f))
          return json_dict

-     def sfreq(self, data_filepath):
-         json_files = self.get_bids_metadata_files(data_filepath, 'eeg.json')
-         if len(json_files) == 0:
-             raise ValueError('No eeg.json found')
-
-         metadata = self.resolve_bids_json(json_files)
-         if 'SamplingFrequency' not in metadata:
-             raise ValueError('SamplingFrequency not found in metadata')
-         else:
-             return metadata['SamplingFrequency']
-
-     def task(self, data_filepath):
-         return self.get_property_from_filename('task', data_filepath)
-
-     def session(self, data_filepath):
-         return self.get_property_from_filename('session', data_filepath)
-
-     def run(self, data_filepath):
-         return self.get_property_from_filename('run', data_filepath)
-
-     def subject(self, data_filepath):
-         return self.get_property_from_filename('sub', data_filepath)
-
-     def num_channels(self, data_filepath):
-         channels_tsv = pd.read_csv(self.get_bids_metadata_files(data_filepath, 'channels.tsv')[0], sep='\t')
-         return len(channels_tsv)
+     def get_bids_file_attribute(self, attribute, data_filepath):
+         entities = self.layout.parse_file_entities(data_filepath)
+         bidsfile = self.layout.get(**entities)[0]
+         attributes = bidsfile.get_entities(metadata="all")
+         attribute_mapping = {
+             "sfreq": "SamplingFrequency",
+             "modality": "datatype",
+             "task": "task",
+             "session": "session",
+             "run": "run",
+             "subject": "subject",
+             "ntimes": "RecordingDuration",
+             "nchans": "EEGChannelCount",
+         }
+         attribute_value = attributes.get(attribute_mapping.get(attribute), None)
+         return attribute_value

      def channel_labels(self, data_filepath):
-         channels_tsv = pd.read_csv(self.get_bids_metadata_files(data_filepath, 'channels.tsv')[0], sep='\t')
-         return channels_tsv['name'].tolist()
-
+         channels_tsv = pd.read_csv(
+             self.get_bids_metadata_files(data_filepath, "channels.tsv")[0], sep="\t"
+         )
+         return channels_tsv["name"].tolist()
+
      def channel_types(self, data_filepath):
-         channels_tsv = pd.read_csv(self.get_bids_metadata_files(data_filepath, 'channels.tsv')[0], sep='\t')
-         return channels_tsv['type'].tolist()
-
+         channels_tsv = pd.read_csv(
+             self.get_bids_metadata_files(data_filepath, "channels.tsv")[0], sep="\t"
+         )
+         return channels_tsv["type"].tolist()
+
      def num_times(self, data_filepath):
-         eeg_jsons = self.get_bids_metadata_files(data_filepath, 'eeg.json')
+         eeg_jsons = self.get_bids_metadata_files(data_filepath, "eeg.json")
          eeg_json_dict = self.merge_json_inheritance(eeg_jsons)
-         return int(eeg_json_dict['SamplingFrequency'] * eeg_json_dict['RecordingDuration'])
-
+         return int(
+             eeg_json_dict["SamplingFrequency"] * eeg_json_dict["RecordingDuration"]
+         )
+
      def subject_participant_tsv(self, data_filepath):
-         '''Get participants_tsv info of a subject based on filepath'''
-         participants_tsv = pd.read_csv(self.get_bids_metadata_files(data_filepath, 'participants.tsv')[0], sep='\t')
+         """Get participants_tsv info of a subject based on filepath"""
+         participants_tsv = pd.read_csv(
+             self.get_bids_metadata_files(data_filepath, "participants.tsv")[0], sep="\t"
+         )
+         # if participants_tsv is not empty
+         if participants_tsv.empty:
+             return {}
          # set 'participant_id' as index
-         participants_tsv.set_index('participant_id', inplace=True)
-         subject = f'sub-{self.subject(data_filepath)}'
+         participants_tsv.set_index("participant_id", inplace=True)
+         subject = f"sub-{self.get_bids_file_attribute('subject', data_filepath)}"
          return participants_tsv.loc[subject].to_dict()
-
+
      def eeg_json(self, data_filepath):
-         eeg_jsons = self.get_bids_metadata_files(data_filepath, 'eeg.json')
+         eeg_jsons = self.get_bids_metadata_files(data_filepath, "eeg.json")
          eeg_json_dict = self.merge_json_inheritance(eeg_jsons)
          return eeg_json_dict
-
+
      def channel_tsv(self, data_filepath):
-         channels_tsv = pd.read_csv(self.get_bids_metadata_files(data_filepath, 'channels.tsv')[0], sep='\t')
+         channels_tsv = pd.read_csv(
+             self.get_bids_metadata_files(data_filepath, "channels.tsv")[0], sep="\t"
+         )
          channel_tsv = channels_tsv.to_dict()
          # 'name' and 'type' now have a dictionary of index-value. Convert them to list
-         for list_field in ['name', 'type', 'units']:
+         for list_field in ["name", "type", "units"]:
              channel_tsv[list_field] = list(channel_tsv[list_field].values())
-         return channel_tsv
+         return channel_tsv
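The headline change in `data_utils.py` is that `EEGBIDSDataset` (renamed from `BIDSDataset`) now discovers recordings through pybids' `BIDSLayout` instead of a cached directory scan, and the old per-field accessors (`sfreq`, `task`, `subject`, `num_channels`, ...) collapse into a single `get_bids_file_attribute`. A sketch of the resulting flow, using only methods visible in the diff (the dataset path is illustrative):

```python
# Sketch of the 0.1.0 EEGBIDSDataset flow; "./ds002718" is an assumed local
# BIDS dataset, not something this diff ships.
from eegdash.data_utils import EEGBIDSDataset

bids = EEGBIDSDataset(data_dir="./ds002718", dataset="ds002718")
first = bids.get_files()[0]  # recordings discovered via BIDSLayout

# One accessor replaces the old sfreq()/task()/subject() methods:
sfreq = bids.get_bids_file_attribute("sfreq", first)
nchans = bids.get_bids_file_attribute("nchans", first)
```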
@@ -0,0 +1,25 @@
+ # Features datasets
+ from .datasets import FeaturesConcatDataset, FeaturesDataset
+ from .decorators import (
+     FeatureKind,
+     FeaturePredecessor,
+     bivariate_feature,
+     directed_bivariate_feature,
+     multivariate_feature,
+     univariate_feature,
+ )
+
+ # Feature extraction
+ from .extractors import (
+     BivariateFeature,
+     DirectedBivariateFeature,
+     FeatureExtractor,
+     FitableFeature,
+     MultivariateFeature,
+     UnivariateFeature,
+ )
+
+ # Features:
+ from .feature_bank import *
+ from .serialization import load_features_concat_dataset
+ from .utils import extract_features, fit_feature_extractors
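The final hunk, whose file header is absent from this diff, adds a package initializer that re-exports feature dataset containers, decorators, extractor classes, and helper functions; from the relative imports it is presumably `eegdash/features/__init__.py`. Under that assumption, a sketch limited to what the import list itself guarantees (no call signatures are implied by the diff):

```python
# Assumes the new file is eegdash/features/__init__.py; the diff omits the
# header, so the subpackage name is inferred. These names become importable
# from the subpackage per the re-exports above.
from eegdash.features import (
    FeatureExtractor,
    FeaturesConcatDataset,
    FeaturesDataset,
    extract_features,
    fit_feature_extractors,
    load_features_concat_dataset,
)
```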