eegdash 0.0.9__py3-none-any.whl → 0.2.0__py3-none-any.whl

eegdash/data_utils.py CHANGED
@@ -1,90 +1,122 @@
+import json
+import logging
 import os
-import sys
-from joblib import Parallel, delayed
+import re
+from pathlib import Path
+from typing import Any
+
 import mne
+import mne_bids
 import numpy as np
 import pandas as pd
-from pathlib import Path
-import re
-import json
-from mne.io import BaseRaw
-from mne._fiff.utils import _find_channels, _read_segments_file
 import s3fs
-import tempfile
+from bids import BIDSLayout
+from joblib import Parallel, delayed
 from mne._fiff.utils import _read_segments_file
-from braindecode.datasets import BaseDataset
-import mne_bids
+from mne.io import BaseRaw
 from mne_bids import (
     BIDSPath,
 )
-from bids import BIDSLayout
 
-class EEGDashBaseDataset(BaseDataset):
-    """Returns samples from an mne.io.Raw object along with a target.
+from braindecode.datasets import BaseDataset
 
-    Dataset which serves samples from an mne.io.Raw object along with a target.
-    The target is unique for the dataset, and is obtained through the
-    `description` attribute.
+logger = logging.getLogger("eegdash")
 
-    Parameters
-    ----------
-    raw : mne.io.Raw
-        Continuous data.
-    description : dict | pandas.Series | None
-        Holds additional description about the continuous signal / subject.
-    target_name : str | tuple | None
-        Name(s) of the index in `description` that should be used to provide the
-        target (e.g., to be used in a prediction task later on).
-    transform : callable | None
-        On-the-fly transform applied to the example before it is returned.
+
+class EEGDashBaseDataset(BaseDataset):
+    """A single EEG recording hosted on AWS S3 and cached locally upon first access.
+
+    This is a subclass of braindecode's BaseDataset and can consequently be used in
+    conjunction with the preprocessing and training pipelines of braindecode.
     """
-    AWS_BUCKET = 's3://openneuro.org'
-    def __init__(self, record, cache_dir, **kwargs):
+
+    AWS_BUCKET = "s3://openneuro.org"
+
+    def __init__(
+        self,
+        record: dict[str, Any],
+        cache_dir: str,
+        s3_bucket: str | None = None,
+        **kwargs,
+    ):
+        """Create a new EEGDashBaseDataset instance. Users do not usually need to call
+        this directly -- instead use the EEGDashDataset class to load a collection of
+        these recordings from a local BIDS folder or via a database query.
+
+        Parameters
+        ----------
+        record : dict
+            A fully resolved metadata record for the data to load.
+        cache_dir : str
+            A local directory where the data will be cached.
+        s3_bucket : str | None
+            Optional S3 bucket to fetch the data from; defaults to AWS_BUCKET.
+        kwargs : dict
+            Additional keyword arguments to pass to the BaseDataset constructor.
+
+        """
         super().__init__(None, **kwargs)
         self.record = record
         self.cache_dir = Path(cache_dir)
         bids_kwargs = self.get_raw_bids_args()
-        self.bidspath = BIDSPath(root=self.cache_dir / record['dataset'], datatype='eeg', suffix='eeg', **bids_kwargs)
-        self.s3file = self.get_s3path(record['bidspath'])
-        self.filecache = self.cache_dir / record['bidspath']
-        self.bids_dependencies = record['bidsdependencies']
+
+        self.bidspath = BIDSPath(
+            root=self.cache_dir / record["dataset"],
+            datatype="eeg",
+            suffix="eeg",
+            **bids_kwargs,
+        )
+        self.s3_bucket = s3_bucket if s3_bucket else self.AWS_BUCKET
+        self.s3file = self.get_s3path(record["bidspath"])
+        self.filecache = self.cache_dir / record["bidspath"]
+        self.bids_dependencies = record["bidsdependencies"]
         self._raw = None
-        # if os.path.exists(self.filecache):
-        #     self.raw = mne_bids.read_raw_bids(self.bidspath, verbose=False)
 
-    def get_s3path(self, filepath):
-        return f"{self.AWS_BUCKET}/{filepath}"
+    def get_s3path(self, filepath: str) -> str:
+        """Helper to form an AWS S3 URI for the given relative filepath."""
+        return f"{self.s3_bucket}/{filepath}"
 
-    def _download_s3(self):
+    def _download_s3(self) -> None:
+        """Fetch the given data from its S3 location and cache it locally."""
         self.filecache.parent.mkdir(parents=True, exist_ok=True)
-        filesystem = s3fs.S3FileSystem(anon=True, client_kwargs={'region_name': 'us-east-2'})
+        filesystem = s3fs.S3FileSystem(
+            anon=True, client_kwargs={"region_name": "us-east-2"}
+        )
         filesystem.download(self.s3file, self.filecache)
         self.filenames = [self.filecache]
 
-    def _download_dependencies(self):
-        filesystem = s3fs.S3FileSystem(anon=True, client_kwargs={'region_name': 'us-east-2'})
+    def _download_dependencies(self) -> None:
+        """Download all BIDS dependency files (metadata files, recording sidecar files)
+        from S3 and cache them locally.
+        """
+        filesystem = s3fs.S3FileSystem(
+            anon=True, client_kwargs={"region_name": "us-east-2"}
+        )
         for dep in self.bids_dependencies:
             s3path = self.get_s3path(dep)
             filepath = self.cache_dir / dep
             if not filepath.exists():
                 filepath.parent.mkdir(parents=True, exist_ok=True)
-                filesystem.download(s3path, filepath)
+                filesystem.download(s3path, filepath)
 
-    def get_raw_bids_args(self):
-        desired_fields = ['subject', 'session', 'task', 'run']
+    def get_raw_bids_args(self) -> dict[str, Any]:
+        """Helper to restrict the metadata record to the fields needed to locate a BIDS
+        recording.
+        """
+        desired_fields = ["subject", "session", "task", "run"]
         return {k: self.record[k] for k in desired_fields if self.record[k]}
 
-    def check_and_get_raw(self):
-        if not os.path.exists(self.filecache): # not preload
+    def check_and_get_raw(self) -> None:
+        """Download the S3 file and BIDS dependencies if not already cached."""
+        if not os.path.exists(self.filecache):  # not preload
             if self.bids_dependencies:
                 self._download_dependencies()
             self._download_s3()
         if self._raw is None:
             self._raw = mne_bids.read_raw_bids(self.bidspath, verbose=False)
 
-    def __getitem__(self, index):
-        # self.check_and_get_raw()
+    # === BaseDataset and PyTorch Dataset interface ===
 
+    def __getitem__(self, index):
+        """Main function to access a sample from the dataset."""
         X = self.raw[:, index][0]
         y = None
         if self.target_name is not None:
@@ -94,15 +126,21 @@ class EEGDashBaseDataset(BaseDataset):
         if self.transform is not None:
             X = self.transform(X)
         return X, y
-
-    def __len__(self):
+
+    def __len__(self) -> int:
+        """Return the number of samples in the dataset."""
         if self._raw is None:
-            return int(self.record['ntimes'] * self.record['sampling_frequency'])
+            # FIXME: this is a bit strange and should definitely not change as a side effect
+            # of accessing the data (which it will, since ntimes is the actual length but rounded down)
+            return int(self.record["ntimes"] * self.record["sampling_frequency"])
         else:
             return len(self._raw)
 
     @property
     def raw(self):
+        """Return the MNE Raw object for this recording. This will perform the actual
+        retrieval if it has not been done yet.
+        """
         if self._raw is None:
             self.check_and_get_raw()
         return self._raw
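For context, here is a minimal usage sketch of the class above. The record values are hypothetical (such records are normally produced for you by EEGDashDataset); the keys mirror those read by `__init__` and `__len__`:

```python
# Hypothetical metadata record; keys mirror those read in __init__/__len__.
record = {
    "dataset": "ds002718",
    "bidspath": "ds002718/sub-012/eeg/sub-012_task-FaceRecognition_eeg.set",
    "bidsdependencies": [],        # sidecar files fetched before the recording
    "subject": "012",
    "session": None,
    "task": "FaceRecognition",
    "run": None,
    "ntimes": 288.0,               # duration in seconds (see the FIXME in __len__)
    "sampling_frequency": 250.0,
}

ds = EEGDashBaseDataset(record, cache_dir="./.eegdash_cache")
print(len(ds))  # metadata-based estimate; nothing downloaded yet
X, y = ds[0]    # first access downloads from S3 and reads via mne_bids
```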
@@ -111,59 +149,55 @@ class EEGDashBaseDataset(BaseDataset):
     def raw(self, raw):
         self._raw = raw
 
+
 class EEGDashBaseRaw(BaseRaw):
-    r"""MNE Raw object from EEG-Dash connection with Openneuro S3 file.
+    """Wrapper around the MNE BaseRaw class that automatically fetches the data from
+    S3 (when _read_segment is called) and caches it locally. Currently for internal
+    use.
 
     Parameters
     ----------
     input_fname : path-like
         Path to the S3 file
-    eog : list | tuple | 'auto'
-        Names or indices of channels that should be designated EOG channels.
-        If 'auto', the channel names containing ``EOG`` or ``EYE`` are used.
-        Defaults to empty tuple.
-    %(preload)s
-        Note that preload=False will be effective only if the data is stored
-        in a separate binary file.
-    %(uint16_codec)s
-    %(montage_units)s
-    %(verbose)s
+    metadata : dict
+        The metadata record for the recording (e.g., from the database).
+    preload : bool
+        Whether to preload the data before the first access.
+    cache_dir : str
+        Local path under which the data will be cached.
+    bids_dependencies : list
+        List of additional BIDS metadata files that should be downloaded and cached
+        alongside the main recording file.
+    verbose : str | int | None
+        Optionally the verbosity level for MNE logging (see the MNE documentation
+        for possible values).
 
     See Also
     --------
     mne.io.Raw : Documentation of attributes and methods.
 
-    Notes
-    -----
-    .. versionadded:: 0.11.0
     """
 
-    AWS_BUCKET = 's3://openneuro.org'
+    AWS_BUCKET = "s3://openneuro.org"
+
     def __init__(
         self,
-        input_fname,
-        metadata,
-        eog=(),
-        preload=False,
+        input_fname: str,
+        metadata: dict[str, Any],
+        preload: bool = False,
         *,
-        cache_dir='./.eegdash_cache',
-        bids_dependencies:list = [],
-        uint16_codec=None,
-        montage_units="auto",
-        verbose=None,
+        cache_dir: str = "./.eegdash_cache",
+        bids_dependencies: list[str] = [],
+        verbose: Any = None,
     ):
-        '''
-        Get to work with S3 endpoint first, no caching
-        '''
+        """Get to work with the S3 endpoint first, no caching."""
         # Create a simple RawArray
-        sfreq = metadata['sfreq'] # Sampling frequency
-        n_times = metadata['n_times']
-        ch_names = metadata['ch_names']
+        sfreq = metadata["sfreq"]  # Sampling frequency
+        n_times = metadata["n_times"]
+        ch_names = metadata["ch_names"]
         ch_types = []
-        for ch in metadata['ch_types']:
+        for ch in metadata["ch_types"]:
             chtype = ch.lower()
-            if chtype == 'heog' or chtype == 'veog':
-                chtype = 'eog'
+            if chtype == "heog" or chtype == "veog":
+                chtype = "eog"
             ch_types.append(chtype)
         info = mne.create_info(ch_names=ch_names, sfreq=sfreq, ch_types=ch_types)
         self.s3file = self.get_s3path(input_fname)
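The constructor above builds its mne.Info purely from the supplied metadata record. A sketch of the expected shape (all values hypothetical):

```python
# Hypothetical metadata record; keys match what __init__ reads above.
metadata = {
    "sfreq": 250.0,        # sampling frequency in Hz
    "n_times": 72000,      # number of samples in the recording
    "ch_names": ["Cz", "Pz", "HEOG"],
    "ch_types": ["EEG", "EEG", "HEOG"],  # 'heog'/'veog' are normalized to 'eog'
}

raw = EEGDashBaseRaw(
    "ds002718/sub-012/eeg/sub-012_task-FaceRecognition_eeg.set",
    metadata,
    preload=False,  # data is only fetched from S3 once _read_segment runs
)
```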
@@ -178,7 +212,7 @@ class EEGDashBaseRaw(BaseRaw):
         super().__init__(
             info,
             preload,
-            last_samps=[n_times-1],
+            last_samps=[n_times - 1],
             orig_format="single",
             verbose=verbose,
         )
@@ -188,12 +222,16 @@ class EEGDashBaseRaw(BaseRaw):
 
     def _download_s3(self):
         self.filecache.parent.mkdir(parents=True, exist_ok=True)
-        filesystem = s3fs.S3FileSystem(anon=True, client_kwargs={'region_name': 'us-east-2'})
+        filesystem = s3fs.S3FileSystem(
+            anon=True, client_kwargs={"region_name": "us-east-2"}
+        )
         filesystem.download(self.s3file, self.filecache)
         self.filenames = [self.filecache]
 
     def _download_dependencies(self):
-        filesystem = s3fs.S3FileSystem(anon=True, client_kwargs={'region_name': 'us-east-2'})
+        filesystem = s3fs.S3FileSystem(
+            anon=True, client_kwargs={"region_name": "us-east-2"}
+        )
         for dep in self.bids_dependencies:
             s3path = self.get_s3path(dep)
             filepath = self.cache_dir / dep
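Both download helpers rely on the same anonymous s3fs pattern. The equivalent standalone calls, with hypothetical paths, would be:

```python
import os

import s3fs

# Anonymous access works because the OpenNeuro bucket is public (us-east-2).
filesystem = s3fs.S3FileSystem(anon=True, client_kwargs={"region_name": "us-east-2"})

os.makedirs("./.eegdash_cache/ds002718", exist_ok=True)  # mirrors the mkdir above
filesystem.download(
    "s3://openneuro.org/ds002718/dataset_description.json",
    "./.eegdash_cache/ds002718/dataset_description.json",
)
```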
@@ -204,34 +242,56 @@ class EEGDashBaseRaw(BaseRaw):
     def _read_segment(
         self, start=0, stop=None, sel=None, data_buffer=None, *, verbose=None
     ):
-        if not os.path.exists(self.filecache): # not preload
+        if not os.path.exists(self.filecache):  # not preload
             if self.bids_dependencies:
                 self._download_dependencies()
             self._download_s3()
-        else: # not preload and file is not cached
+        else:  # not preload and file is already cached
             self.filenames = [self.filecache]
         return super()._read_segment(start, stop, sel, data_buffer, verbose=verbose)
-
+
     def _read_segment_file(self, data, idx, fi, start, stop, cals, mult):
         """Read a chunk of data from the file."""
         _read_segments_file(self, data, idx, fi, start, stop, cals, mult, dtype="<f4")
 
 
-class EEGBIDSDataset():
-    ALLOWED_FILE_FORMAT = ['eeglab', 'brainvision', 'biosemi', 'european']
+class EEGBIDSDataset:
+    """A one-stop shop interface to a local BIDS dataset containing EEG recordings.
+
+    This is mainly tailored to the needs of the EEGDash application and is used to
+    centralize interactions with the BIDS dataset, such as parsing the metadata.
+
+    Parameters
+    ----------
+    data_dir : str | Path
+        The path to the local BIDS dataset directory.
+    dataset : str
+        A name for the dataset.
+
+    """
+
+    ALLOWED_FILE_FORMAT = ["eeglab", "brainvision", "biosemi", "european"]
     RAW_EXTENSIONS = {
-        '.set': ['.set', '.fdt'], # eeglab
-        '.edf': ['.edf'], # european
-        '.vhdr': ['.eeg', '.vhdr', '.vmrk', '.dat', '.raw'], # brainvision
-        '.bdf': ['.bdf'], # biosemi
-    }
-    METADATA_FILE_EXTENSIONS = ['eeg.json', 'channels.tsv', 'electrodes.tsv', 'events.tsv', 'events.json']
-    def __init__(self,
-                 data_dir=None, # location of bids dataset
-                 dataset='', # dataset name
-                 ):
+        ".set": [".set", ".fdt"],  # eeglab
+        ".edf": [".edf"],  # european
+        ".vhdr": [".eeg", ".vhdr", ".vmrk", ".dat", ".raw"],  # brainvision
+        ".bdf": [".bdf"],  # biosemi
+    }
+    METADATA_FILE_EXTENSIONS = [
+        "eeg.json",
+        "channels.tsv",
+        "electrodes.tsv",
+        "events.tsv",
+        "events.json",
+    ]
+
+    def __init__(
+        self,
+        data_dir=None,  # location of bids dataset
+        dataset="",  # dataset name
+    ):
         if data_dir is None or not os.path.exists(data_dir):
-            raise ValueError('data_dir must be specified and must exist')
+            raise ValueError("data_dir must be specified and must exist")
         self.bidsdir = Path(data_dir)
         self.dataset = dataset
         assert str(self.bidsdir).endswith(self.dataset)
@@ -239,73 +299,87 @@ class EEGBIDSDataset():
 
         # get all recording files in the bids directory
         self.files = self.get_recordings(self.layout)
-        assert len(self.files) > 0, ValueError('Unable to construct EEG dataset. No EEG recordings found.')
-        assert self.check_eeg_dataset(), ValueError('Dataset is not an EEG dataset.')
-        # temp_dir = (Path().resolve() / 'data')
-        # if not os.path.exists(temp_dir):
-        #     os.mkdir(temp_dir)
-        # if not os.path.exists(temp_dir / f'{dataset}_files.npy'):
-        #     self.files = self.get_files_with_extension_parallel(self.bidsdir, extension=self.RAW_EXTENSION[self.raw_format])
-        #     np.save(temp_dir / f'{dataset}_files.npy', self.files)
-        # else:
-        #     self.files = np.load(temp_dir / f'{dataset}_files.npy', allow_pickle=True)
-
-    def check_eeg_dataset(self):
-        return self.get_bids_file_attribute('modality', self.files[0]).lower() == 'eeg'
-
-    def get_recordings(self, layout:BIDSLayout):
+        assert len(self.files) > 0, ValueError(
+            "Unable to construct EEG dataset. No EEG recordings found."
+        )
+        assert self.check_eeg_dataset(), ValueError("Dataset is not an EEG dataset.")
+
+    def check_eeg_dataset(self) -> bool:
+        """Check if the dataset is an EEG dataset."""
+        return self.get_bids_file_attribute("modality", self.files[0]).lower() == "eeg"
+
+    def get_recordings(self, layout: BIDSLayout) -> list[str]:
+        """Get a list of all EEG recording files in the BIDS layout."""
         files = []
         for ext, exts in self.RAW_EXTENSIONS.items():
-            files = layout.get(extension=ext, return_type='filename')
+            files = layout.get(extension=ext, return_type="filename")
             if files:
                 break
         return files
 
-    def get_relative_bidspath(self, filename):
-        bids_parent_dir = self.bidsdir.parent
+    def get_relative_bidspath(self, filename: str) -> str:
+        """Make the given file path relative to the parent of the BIDS directory."""
+        bids_parent_dir = self.bidsdir.parent.absolute()
         return str(Path(filename).relative_to(bids_parent_dir))
 
-    def get_property_from_filename(self, property, filename):
+    def get_property_from_filename(self, property: str, filename: str) -> str:
+        """Parse a property out of a BIDS-compliant filename. Returns an empty string
+        if not found.
+        """
         import platform
+
         if platform.system() == "Windows":
-            lookup = re.search(rf'{property}-(.*?)[_\\]', filename)
+            lookup = re.search(rf"{property}-(.*?)[_\\]", filename)
         else:
-            lookup = re.search(rf'{property}-(.*?)[_\/]', filename)
-        return lookup.group(1) if lookup else ''
-
-    def merge_json_inheritance(self, json_files):
-        '''
-        Merge list of json files found by get_bids_file_inheritance,
-        expecting the order (from left to right) is from lowest level to highest level,
-        and return a merged dictionary
-        '''
+            lookup = re.search(rf"{property}-(.*?)[_\/]", filename)
+        return lookup.group(1) if lookup else ""
+
+    def merge_json_inheritance(self, json_files: list[str | Path]) -> dict:
+        """Internal helper to merge the list of JSON files found by
+        get_bids_file_inheritance, expecting the order (from left to right) to run
+        from lowest level to highest level, and return a merged dictionary.
+        """
         json_files.reverse()
         json_dict = {}
         for f in json_files:
-            json_dict.update(json.load(open(f)))
+            json_dict.update(json.load(open(f)))  # FIXME: should close file
         return json_dict
 
-    def get_bids_file_inheritance(self, path, basename, extension):
-        '''
-        Get all files with given extension that applies to the basename file
-        following the BIDS inheritance principle in the order of lowest level first
-        @param
-            basename: bids file basename without _eeg.set extension for example
-            extension: e.g. channels.tsv
-        '''
-        top_level_files = ['README', 'dataset_description.json', 'participants.tsv']
+    def get_bids_file_inheritance(
+        self, path: str | Path, basename: str, extension: str
+    ) -> list[Path]:
+        """Get all file paths that apply to the basename file in the specified directory
+        and that end with the specified suffix, recursively searching parent directories
+        (following the BIDS inheritance principle in the order of lowest level first).
+
+        Parameters
+        ----------
+        path : str | Path
+            The directory path to search for files.
+        basename : str
+            The BIDS file basename, e.g. without the _eeg.set extension.
+        extension : str
+            Only consider files that end with the specified suffix, e.g. channels.tsv.
+
+        Returns
+        -------
+        list[Path]
+            A list of file paths that match the given basename and extension.
+
+        """
+        top_level_files = ["README", "dataset_description.json", "participants.tsv"]
         bids_files = []
 
         # check if path is str object
         if isinstance(path, str):
             path = Path(path)
         if not path.exists():
-            raise ValueError('path {path} does not exist')
+            raise ValueError(f"path {path} does not exist")
 
         # check if file is in current path
         for file in os.listdir(path):
             # target_file = path / f"{cur_file_basename}_{extension}"
-            if os.path.isfile(path/file):
+            if os.path.isfile(path / file):
                 # check if file has the desired extension
                 # check if file basename has extension
                 if file.endswith(extension):
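To make the inheritance walk concrete, here is a sketch against a hypothetical layout:

```python
# Hypothetical tree; get_bids_metadata_files climbs from the recording's
# directory to the dataset root, collecting every file ending in the suffix:
#
#   ds002718/
#       eeg.json                                   <- dataset level (collected last)
#       sub-012/eeg/
#           sub-012_task-FaceRecognition_eeg.json  <- recording level (collected first)
#           sub-012_task-FaceRecognition_eeg.set

bids = EEGBIDSDataset(data_dir="ds002718", dataset="ds002718")
jsons = bids.get_bids_metadata_files(
    "ds002718/sub-012/eeg/sub-012_task-FaceRecognition_eeg.set", "eeg.json"
)
# The list is lowest level first; merge_json_inheritance reverses it before
# merging, so recording-level keys override dataset-level ones.
merged = bids.merge_json_inheritance(jsons)
```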
@@ -317,38 +391,54 @@ class EEGBIDSDataset():
                 return bids_files
         else:
             # call get_bids_file_inheritance recursively with parent directory
-            bids_files.extend(self.get_bids_file_inheritance(path.parent, basename, extension))
+            bids_files.extend(
+                self.get_bids_file_inheritance(path.parent, basename, extension)
+            )
             return bids_files
 
-    def get_bids_metadata_files(self, filepath, metadata_file_extension):
-        """
-        (Wrapper for self.get_bids_file_inheritance)
-        Get all BIDS metadata files that are associated with the given filepath, following the BIDS inheritance principle.
-
-        Args:
-            filepath (str or Path): The filepath to get the associated metadata files for.
-            metadata_files_extensions (list): A list of file extensions to search for metadata files.
-
-        Returns:
-            list: A list of filepaths for all the associated metadata files
+    def get_bids_metadata_files(
+        self, filepath: str | Path, metadata_file_extension: list[str]
+    ) -> list[Path]:
+        """Retrieve all metadata file paths that apply to a given data file path and
+        that end with a specific suffix (following the BIDS inheritance principle).
+
+        Parameters
+        ----------
+        filepath : str | Path
+            The filepath to get the associated metadata files for.
+        metadata_file_extension : str
+            Consider only metadata files that end with the specified suffix,
+            e.g., channels.tsv or eeg.json.
+
+        Returns
+        -------
+        list[Path]
+            A list of filepaths for all matching metadata files.
+
         """
         if isinstance(filepath, str):
             filepath = Path(filepath)
         if not filepath.exists():
-            raise ValueError('filepath {filepath} does not exist')
+            raise ValueError(f"filepath {filepath} does not exist")
         path, filename = os.path.split(filepath)
-        basename = filename[:filename.rfind('_')]
+        basename = filename[: filename.rfind("_")]
         # metadata files
-        meta_files = self.get_bids_file_inheritance(path, basename, metadata_file_extension)
+        meta_files = self.get_bids_file_inheritance(
+            path, basename, metadata_file_extension
+        )
         return meta_files
-
-    def scan_directory(self, directory, extension):
+
+    def scan_directory(self, directory: str, extension: str) -> list[Path]:
+        """Return a list of file paths that end with the given extension in the
+        specified directory. Ignores certain special directories such as .git,
+        .datalad, derivatives, and code.
+        """
         result_files = []
-        directory_to_ignore = ['.git', '.datalad', 'derivatives', 'code']
+        directory_to_ignore = [".git", ".datalad", "derivatives", "code"]
         with os.scandir(directory) as entries:
             for entry in entries:
                 if entry.is_file() and entry.name.endswith(extension):
-                    print('Adding ', entry.path)
+                    print("Adding ", entry.path)
                     result_files.append(entry.path)
                 elif entry.is_dir():
                     # check that entry path doesn't contain any name in ignore list
@@ -356,18 +446,41 @@ class EEGBIDSDataset():
                     result_files.append(entry.path)  # Add directory to scan later
         return result_files
 
-    def get_files_with_extension_parallel(self, directory, extension='.set', max_workers=-1):
+    def get_files_with_extension_parallel(
+        self, directory: str, extension: str = ".set", max_workers: int = -1
+    ) -> list[Path]:
+        """Efficiently scan a directory and its subdirectories for files that end with
+        the given extension.
+
+        Parameters
+        ----------
+        directory : str
+            The root directory to scan for files.
+        extension : str
+            Only consider files that end with this suffix, e.g. '.set'.
+        max_workers : int
+            Optionally specify the maximum number of worker threads to use for
+            parallel scanning. Defaults to all available CPU cores if set to -1.
+
+        Returns
+        -------
+        list[Path]
+            A list of filepaths for all matching files.
+
+        """
         result_files = []
         dirs_to_scan = [directory]
 
         # Use joblib.Parallel and delayed to parallelize directory scanning
         while dirs_to_scan:
-            print(f"Scanning {len(dirs_to_scan)} directories...", dirs_to_scan)
+            logger.info(
+                f"Scanning {len(dirs_to_scan)} directories: {dirs_to_scan}"
+            )
             # Run the scan_directory function in parallel across directories
             results = Parallel(n_jobs=max_workers, prefer="threads", verbose=1)(
                 delayed(self.scan_directory)(d, extension) for d in dirs_to_scan
             )
-
+
             # Reset the directories to scan and process the results
             dirs_to_scan = []
             for res in results:
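The loop above implements a breadth-first walk in which each level of directories is scanned in parallel threads. The same pattern in standalone form (the function names here are mine, not the package's):

```python
import os

from joblib import Parallel, delayed

def scan_one(directory: str, extension: str) -> list[str]:
    # One level of the walk: matching files plus subdirectories to visit next.
    found = []
    with os.scandir(directory) as entries:
        for entry in entries:
            if entry.is_file() and entry.name.endswith(extension):
                found.append(entry.path)
            elif entry.is_dir():
                found.append(entry.path)
    return found

def scan_tree(root: str, extension: str = ".set", n_jobs: int = -1) -> list[str]:
    files, frontier = [], [root]
    while frontier:  # one Parallel call per depth level of the tree
        results = Parallel(n_jobs=n_jobs, prefer="threads")(
            delayed(scan_one)(d, extension) for d in frontier
        )
        frontier = []
        for res in results:
            for path in res:
                (frontier if os.path.isdir(path) else files).append(path)
    return files
```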
@@ -376,14 +489,20 @@ class EEGBIDSDataset():
                     dirs_to_scan.append(path)  # Queue up subdirectories to scan
                 else:
                     result_files.append(path)  # Add files to the final result
-            print(f"Current number of files: {len(result_files)}")
+            logger.info(f"Found {len(result_files)} files.")
 
         return result_files
 
-    def load_and_preprocess_raw(self, raw_file, preprocess=False):
-        print(f"Loading {raw_file}")
-        EEG = mne.io.read_raw_eeglab(raw_file, preload=True, verbose='error')
-
+    def load_and_preprocess_raw(
+        self, raw_file: str, preprocess: bool = False
+    ) -> np.ndarray:
+        """Utility function to load a raw data file with MNE, optionally apply some
+        simple (hardcoded) preprocessing, and return the result as a numpy array.
+        Not meant for purposes other than testing or debugging.
+        """
+        logger.info(f"Loading raw data from {raw_file}")
+        EEG = mne.io.read_raw_eeglab(raw_file, preload=True, verbose="error")
+
         if preprocess:
             # highpass filter
             EEG = EEG.filter(l_freq=0.25, h_freq=25, verbose=False)
@@ -391,33 +510,35 @@ class EEGBIDSDataset():
             EEG = EEG.notch_filter(freqs=(60), verbose=False)
             # bring to common sampling rate
             sfreq = 128
-            if EEG.info['sfreq'] != sfreq:
+            if EEG.info["sfreq"] != sfreq:
                 EEG = EEG.resample(sfreq)
-        # # normalize data to zero mean and unit variance
-        # scalar = preprocessing.StandardScaler()
-        # mat_data = scalar.fit_transform(mat_data.T).T # scalar normalize for each feature and expects shape data x features
 
         mat_data = EEG.get_data()
 
         if len(mat_data.shape) > 2:
-            raise ValueError('Expect raw data to be CxT dimension')
+            raise ValueError("Expect raw data to be CxT dimension")
         return mat_data
-
-    def get_files(self):
+
+    def get_files(self) -> list[Path]:
+        """Get all EEG recording file paths (with valid extensions) in the BIDS folder."""
         return self.files
-
-    def resolve_bids_json(self, json_files: list):
-        """
-        Resolve the BIDS JSON files and return a dictionary of the resolved values.
-        Args:
-            json_files (list): A list of JSON files to resolve in order of leaf level first
 
-        Returns:
+    def resolve_bids_json(self, json_files: list[str]) -> dict:
+        """Resolve the BIDS JSON files and return a dictionary of the resolved values.
+
+        Parameters
+        ----------
+        json_files : list
+            A list of JSON file paths to resolve in order of leaf level first.
+
+        Returns
+        -------
             dict: A dictionary of the resolved values.
+
         """
         if len(json_files) == 0:
-            raise ValueError('No JSON files provided')
-        json_files.reverse() # TODO undeterministic
+            raise ValueError("No JSON files provided")
+        json_files.reverse()  # TODO: non-deterministic
 
         json_dict = {}
         for json_file in json_files:
@@ -425,56 +546,78 @@ class EEGBIDSDataset():
                 json_dict.update(json.load(f))
         return json_dict
 
-    def get_bids_file_attribute(self, attribute, data_filepath):
+    def get_bids_file_attribute(self, attribute: str, data_filepath: str) -> Any:
+        """Retrieve a specific attribute from the BIDS file metadata applicable
+        to the provided recording file path.
+        """
         entities = self.layout.parse_file_entities(data_filepath)
         bidsfile = self.layout.get(**entities)[0]
-        attributes = bidsfile.get_entities(metadata='all')
+        attributes = bidsfile.get_entities(metadata="all")
         attribute_mapping = {
-            'sfreq': 'SamplingFrequency',
-            'modality': 'datatype',
-            'task': 'task',
-            'session': 'session',
-            'run': 'run',
-            'subject': 'subject',
-            'ntimes': 'RecordingDuration',
-            'nchans': 'EEGChannelCount'
+            "sfreq": "SamplingFrequency",
+            "modality": "datatype",
+            "task": "task",
+            "session": "session",
+            "run": "run",
+            "subject": "subject",
+            "ntimes": "RecordingDuration",
+            "nchans": "EEGChannelCount",
         }
         attribute_value = attributes.get(attribute_mapping.get(attribute), None)
         return attribute_value
 
-    def channel_labels(self, data_filepath):
-        channels_tsv = pd.read_csv(self.get_bids_metadata_files(data_filepath, 'channels.tsv')[0], sep='\t')
-        return channels_tsv['name'].tolist()
-
-    def channel_types(self, data_filepath):
-        channels_tsv = pd.read_csv(self.get_bids_metadata_files(data_filepath, 'channels.tsv')[0], sep='\t')
-        return channels_tsv['type'].tolist()
-
-    def num_times(self, data_filepath):
-        eeg_jsons = self.get_bids_metadata_files(data_filepath, 'eeg.json')
+    def channel_labels(self, data_filepath: str) -> list[str]:
+        """Get a list of channel labels for the given data file path."""
+        channels_tsv = pd.read_csv(
+            self.get_bids_metadata_files(data_filepath, "channels.tsv")[0], sep="\t"
+        )
+        return channels_tsv["name"].tolist()
+
+    def channel_types(self, data_filepath: str) -> list[str]:
+        """Get a list of channel types for the given data file path."""
+        channels_tsv = pd.read_csv(
+            self.get_bids_metadata_files(data_filepath, "channels.tsv")[0], sep="\t"
+        )
+        return channels_tsv["type"].tolist()
+
+    def num_times(self, data_filepath: str) -> int:
+        """Get the approximate number of time points in the EEG recording based on the BIDS metadata."""
+        eeg_jsons = self.get_bids_metadata_files(data_filepath, "eeg.json")
         eeg_json_dict = self.merge_json_inheritance(eeg_jsons)
-        return int(eeg_json_dict['SamplingFrequency'] * eeg_json_dict['RecordingDuration'])
-
-    def subject_participant_tsv(self, data_filepath):
-        '''Get participants_tsv info of a subject based on filepath'''
-        participants_tsv = pd.read_csv(self.get_bids_metadata_files(data_filepath, 'participants.tsv')[0], sep='\t')
+        return int(
+            eeg_json_dict["SamplingFrequency"] * eeg_json_dict["RecordingDuration"]
+        )
+
+    def subject_participant_tsv(self, data_filepath: str) -> dict[str, Any]:
+        """Get the BIDS participants.tsv record for the subject to which the given
+        file path corresponds, as a dictionary.
+        """
+        participants_tsv = pd.read_csv(
+            self.get_bids_metadata_files(data_filepath, "participants.tsv")[0], sep="\t"
+        )
         # if participants_tsv is not empty
         if participants_tsv.empty:
             return {}
         # set 'participant_id' as index
-        participants_tsv.set_index('participant_id', inplace=True)
+        participants_tsv.set_index("participant_id", inplace=True)
         subject = f"sub-{self.get_bids_file_attribute('subject', data_filepath)}"
         return participants_tsv.loc[subject].to_dict()
-
-    def eeg_json(self, data_filepath):
-        eeg_jsons = self.get_bids_metadata_files(data_filepath, 'eeg.json')
+
+    def eeg_json(self, data_filepath: str) -> dict[str, Any]:
+        """Get the BIDS eeg.json metadata for the given data file path."""
+        eeg_jsons = self.get_bids_metadata_files(data_filepath, "eeg.json")
         eeg_json_dict = self.merge_json_inheritance(eeg_jsons)
         return eeg_json_dict
-
-    def channel_tsv(self, data_filepath):
-        channels_tsv = pd.read_csv(self.get_bids_metadata_files(data_filepath, 'channels.tsv')[0], sep='\t')
+
+    def channel_tsv(self, data_filepath: str) -> dict[str, Any]:
+        """Get the BIDS channels.tsv metadata for the given data file path, as a
+        dictionary of lists and/or single values.
+        """
+        channels_tsv = pd.read_csv(
+            self.get_bids_metadata_files(data_filepath, "channels.tsv")[0], sep="\t"
+        )
         channel_tsv = channels_tsv.to_dict()
         # 'name' and 'type' now have a dictionary of index-value. Convert them to list
-        for list_field in ['name', 'type', 'units']:
+        for list_field in ["name", "type", "units"]:
             channel_tsv[list_field] = list(channel_tsv[list_field].values())
-        return channel_tsv
+        return channel_tsv
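Taken together, these accessors resolve all per-recording metadata from the BIDS sidecar files. A minimal sketch of how they compose (the dataset path is hypothetical):

```python
bids = EEGBIDSDataset(data_dir="ds002718", dataset="ds002718")
f = bids.get_files()[0]

bids.get_bids_file_attribute("sfreq", f)  # e.g. 250.0, from eeg.json
bids.channel_labels(f)                    # channel names from channels.tsv
bids.num_times(f)                         # SamplingFrequency * RecordingDuration
bids.subject_participant_tsv(f)           # the subject's participants.tsv row
```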