eegdash 0.0.8__py3-none-any.whl → 0.0.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

@@ -0,0 +1,103 @@
+ import numbers
+ import numpy as np
+ from scipy import stats
+
+ from ..decorators import univariate_feature
+
+
+ __all__ = [
+     "signal_mean",
+     "signal_variance",
+     "signal_skewness",
+     "signal_kurtosis",
+     "signal_std",
+     "signal_root_mean_square",
+     "signal_peak_to_peak",
+     "signal_quantile",
+     "signal_zero_crossings",
+     "signal_line_length",
+     "signal_hjorth_activity",
+     "signal_hjorth_mobility",
+     "signal_hjorth_complexity",
+     "signal_decorrelation_time",
+ ]
+
+
+ @univariate_feature
+ def signal_mean(x):
+     return x.mean(axis=-1)
+
+
+ @univariate_feature
+ def signal_variance(x, **kwargs):
+     return x.var(axis=-1, **kwargs)
+
+
+ @univariate_feature
+ def signal_std(x, **kwargs):
+     return x.std(axis=-1, **kwargs)
+
+
+ @univariate_feature
+ def signal_skewness(x, **kwargs):
+     return stats.skew(x, axis=x.ndim - 1, **kwargs)
+
+
+ @univariate_feature
+ def signal_kurtosis(x, **kwargs):
+     return stats.kurtosis(x, axis=x.ndim - 1, **kwargs)
+
+
+ @univariate_feature
+ def signal_root_mean_square(x):
+     return np.sqrt(np.power(x, 2).mean(axis=-1))
+
+
+ @univariate_feature
+ def signal_peak_to_peak(x, **kwargs):
+     return np.ptp(x, axis=-1, **kwargs)
+
+
+ @univariate_feature
+ def signal_quantile(x, q: numbers.Number = 0.5, **kwargs):
+     return np.quantile(x, q=q, axis=-1, **kwargs)
+
+
+ @univariate_feature
+ def signal_line_length(x):
+     return np.abs(np.diff(x, axis=-1)).mean(axis=-1)
+
+
+ @univariate_feature
+ def signal_zero_crossings(x, threshold=1e-15):
+     zero_ind = np.logical_and(x > -threshold, x < threshold)
+     zero_cross = np.diff(zero_ind, axis=-1).astype(int).sum(axis=-1)
+     y = x.copy()
+     y[zero_ind] = 0
+     zero_cross += np.sum(np.signbit(y[..., :-1]) != np.signbit(y[..., 1:]), axis=-1)
+     return zero_cross
+
+
+ @univariate_feature
+ def signal_hjorth_mobility(x):
+     return np.diff(x, axis=-1).std(axis=-1) / x.std(axis=-1)
+
+
+ @univariate_feature
+ def signal_hjorth_complexity(x):
+     return np.diff(x, 2, axis=-1).std(axis=-1) / x.std(axis=-1)
+
+
+ @univariate_feature
+ def signal_decorrelation_time(x, fs=1):
+     f = np.fft.fft(x - x.mean(axis=-1, keepdims=True), axis=-1)
+     ac = np.fft.ifft(f.real**2 + f.imag**2, axis=-1)[..., : x.shape[-1] // 2]
+     dct = np.empty(x.shape[:-1])
+     for i in np.ndindex(x.shape[:-1]):
+         dct[i] = np.searchsorted(ac[i] <= 0, True)
+     return dct / fs
+
+
+ # ================================= Aliases =================================
+
+ signal_hjorth_activity = signal_variance
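
The file above defines per-channel statistical and Hjorth features. As a quick orientation (not part of the package), here is a minimal NumPy-only sketch of the Hjorth quantities on a synthetic (n_channels, n_times) array, mirroring the definitions above and bypassing the univariate_feature decorator plumbing:

import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal((4, 1000))  # toy (n_channels, n_times) signal

activity = x.var(axis=-1)                                           # Hjorth activity == variance
mobility = np.diff(x, axis=-1).std(axis=-1) / x.std(axis=-1)        # Hjorth mobility
complexity = np.diff(x, 2, axis=-1).std(axis=-1) / x.std(axis=-1)   # Hjorth complexity

print(activity.shape, mobility.shape, complexity.shape)             # one value per channel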
@@ -0,0 +1,134 @@
+ import numpy as np
+ import numba as nb
+ from scipy.signal import welch
+
+ from ..extractors import FeatureExtractor
+ from ..decorators import FeaturePredecessor, univariate_feature
+
+
+ __all__ = [
+     "SpectralFeatureExtractor",
+     "NormalizedSpectralFeatureExtractor",
+     "DBSpectralFeatureExtractor",
+     "spectral_root_total_power",
+     "spectral_moment",
+     "spectral_entropy",
+     "spectral_edge",
+     "spectral_slope",
+     "spectral_bands_power",
+     "spectral_hjorth_activity",
+     "spectral_hjorth_mobility",
+     "spectral_hjorth_complexity",
+ ]
+
+
+ class SpectralFeatureExtractor(FeatureExtractor):
+     def preprocess(self, x, **kwargs):
+         f_min = kwargs.pop("f_min") if "f_min" in kwargs else None
+         f_max = kwargs.pop("f_max") if "f_max" in kwargs else None
+         kwargs["axis"] = -1
+         f, p = welch(x, **kwargs)
+         if f_min is not None or f_max is not None:
+             f_min_idx = f > f_min if f_min is not None else True
+             f_max_idx = f < f_max if f_max is not None else True
+             idx = np.logical_and(f_min_idx, f_max_idx)
+             f = f[idx]
+             p = p[..., idx]
+         return f, p
+
+
+ @FeaturePredecessor(SpectralFeatureExtractor)
+ class NormalizedSpectralFeatureExtractor(FeatureExtractor):
+     def preprocess(self, *x):
+         return (*x[:-1], x[-1] / x[-1].sum(axis=-1, keepdims=True))
+
+
+ @FeaturePredecessor(SpectralFeatureExtractor)
+ class DBSpectralFeatureExtractor(FeatureExtractor):
+     def preprocess(self, *x, eps=1e-15):
+         return (*x[:-1], 10 * np.log10(x[-1] + eps))
+
+
+ @FeaturePredecessor(SpectralFeatureExtractor)
+ @univariate_feature
+ def spectral_root_total_power(f, p):
+     return np.sqrt(p.sum(axis=-1))
+
+
+ @FeaturePredecessor(NormalizedSpectralFeatureExtractor)
+ @univariate_feature
+ def spectral_moment(f, p):
+     return np.sum(f * p, axis=-1)
+
+
+ @FeaturePredecessor(SpectralFeatureExtractor)
+ @univariate_feature
+ def spectral_hjorth_activity(f, p):
+     return np.sum(p, axis=-1)
+
+
+ @FeaturePredecessor(NormalizedSpectralFeatureExtractor)
+ @univariate_feature
+ def spectral_hjorth_mobility(f, p):
+     return np.sqrt(np.sum(np.power(f, 2) * p, axis=-1))
+
+
+ @FeaturePredecessor(NormalizedSpectralFeatureExtractor)
+ @univariate_feature
+ def spectral_hjorth_complexity(f, p):
+     return np.sqrt(np.sum(np.power(f, 4) * p, axis=-1))
+
+
+ @FeaturePredecessor(NormalizedSpectralFeatureExtractor)
+ @univariate_feature
+ def spectral_entropy(f, p):
+     idx = p > 0
+     plogp = np.zeros_like(p)
+     plogp[idx] = p[idx] * np.log(p[idx])
+     return -np.sum(plogp, axis=-1)
+
+
+ @FeaturePredecessor(NormalizedSpectralFeatureExtractor)
+ @univariate_feature
+ @nb.njit(cache=True, fastmath=True)
+ def spectral_edge(f, p, edge=0.9):
+     se = np.empty(p.shape[:-1])
+     for i in np.ndindex(p.shape[:-1]):
+         se[i] = f[np.searchsorted(np.cumsum(p[i]), edge)]
+     return se
+
+
+ @FeaturePredecessor(DBSpectralFeatureExtractor)
+ @univariate_feature
+ def spectral_slope(f, p):
+     log_f = np.vstack((np.log(f), np.ones(f.shape[0]))).T
+     r = np.linalg.lstsq(log_f, p.reshape(-1, p.shape[-1]).T)[0]
+     r = r.reshape(2, *p.shape[:-1])
+     return {"exp": r[0], "int": r[1]}
+
+
+ @FeaturePredecessor(
+     SpectralFeatureExtractor,
+     NormalizedSpectralFeatureExtractor,
+     DBSpectralFeatureExtractor,
+ )
+ @univariate_feature
+ def spectral_bands_power(
+     f,
+     p,
+     bands={
+         "delta": (1, 4.5),
+         "theta": (4.5, 8),
+         "alpha": (8, 12),
+         "beta": (12, 30),
+     },
+ ):
+     bands_power = dict()
+     for k, v in bands.items():
+         assert isinstance(k, str)
+         assert isinstance(v, tuple)
+         assert len(v) == 2
+         mask = np.logical_and(f > v[0], f < v[1])
+         power = p[..., mask].sum(axis=-1)
+         bands_power[k] = power
+     return bands_power
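
The spectral features above are computed from a Welch periodogram produced by SpectralFeatureExtractor.preprocess. For reference, a standalone sketch of the same band-power computation as spectral_bands_power, using only NumPy and SciPy (the sampling rate and window length here are illustrative):

import numpy as np
from scipy.signal import welch

rng = np.random.default_rng(0)
x = rng.standard_normal((4, 10 * 256))           # (n_channels, n_times), 10 s at 256 Hz

f, p = welch(x, fs=256, nperseg=512, axis=-1)    # power spectral density per channel

bands = {"delta": (1, 4.5), "theta": (4.5, 8), "alpha": (8, 12), "beta": (12, 30)}
band_power = {
    name: p[..., (f > lo) & (f < hi)].sum(axis=-1) for name, (lo, hi) in bands.items()
}
print({name: v.shape for name, v in band_power.items()})  # one value per channel and band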
@@ -0,0 +1,87 @@
+ """
+ Convenience functions for storing and loading of features datasets.
+
+ see also: https://github.com/braindecode/braindecode//blob/master/braindecode/datautil/serialization.py#L165-L229
+ """
+
+ import json
+ from pathlib import Path
+
+ import pandas as pd
+ from joblib import Parallel, delayed
+
+ from mne.io import read_info
+ from braindecode.datautil.serialization import _load_kwargs_json
+
+ from .datasets import (
+     FeaturesDataset,
+     FeaturesConcatDataset,
+ )
+
+
+ def load_features_concat_dataset(path, ids_to_load=None, n_jobs=1):
+     """Load a stored FeaturesConcatDataset of FeaturesDatasets from files.
+
+     Parameters
+     ----------
+     path: str | pathlib.Path
+         Path to the directory of the .fif / -epo.fif and .json files.
+     ids_to_load: list of int | None
+         Ids of specific files to load.
+     n_jobs: int
+         Number of jobs to be used to read files in parallel.
+
+     Returns
+     -------
+     concat_dataset: FeaturesConcatDataset of FeaturesDatasets
+     """
+     # Make sure we always work with a pathlib.Path
+     path = Path(path)
+
+     # else we have a dataset saved in the new way with subdirectories in path
+     # for every dataset with description.json and -feat.parquet,
+     # target_name.json, raw_preproc_kwargs.json, window_kwargs.json,
+     # window_preproc_kwargs.json, features_kwargs.json
+     if ids_to_load is None:
+         ids_to_load = [p.name for p in path.iterdir()]
+         ids_to_load = sorted(ids_to_load, key=lambda i: int(i))
+     ids_to_load = [str(i) for i in ids_to_load]
+
+     datasets = Parallel(n_jobs)(delayed(_load_parallel)(path, i) for i in ids_to_load)
+     return FeaturesConcatDataset(datasets)
+
+
+ def _load_parallel(path, i):
+     sub_dir = path / i
+
+     parquet_name_pattern = "{}-feat.parquet"
+     parquet_file_name = parquet_name_pattern.format(i)
+     parquet_file_path = sub_dir / parquet_file_name
+
+     features = pd.read_parquet(parquet_file_path)
+
+     description_file_path = sub_dir / "description.json"
+     description = pd.read_json(description_file_path, typ="series")
+
+     raw_info_file_path = sub_dir / "raw-info.fif"
+     raw_info = None
+     if raw_info_file_path.exists():
+         raw_info = read_info(raw_info_file_path)
+
+     raw_preproc_kwargs = _load_kwargs_json("raw_preproc_kwargs", sub_dir)
+     window_kwargs = _load_kwargs_json("window_kwargs", sub_dir)
+     window_preproc_kwargs = _load_kwargs_json("window_preproc_kwargs", sub_dir)
+     features_kwargs = _load_kwargs_json("features_kwargs", sub_dir)
+     metadata = pd.read_pickle(path / i / "metadata_df.pkl")
+
+     dataset = FeaturesDataset(
+         features,
+         metadata=metadata,
+         description=description,
+         raw_info=raw_info,
+         raw_preproc_kwargs=raw_preproc_kwargs,
+         window_kwargs=window_kwargs,
+         window_preproc_kwargs=window_preproc_kwargs,
+         features_kwargs=features_kwargs,
+     )
+     return dataset
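
A hedged usage sketch of the loader above, assuming features were previously saved in the per-recording layout described in its comments (one numbered subdirectory per recording containing {i}-feat.parquet, description.json, the *_kwargs.json files, optionally raw-info.fif, and metadata_df.pkl); the path is hypothetical:

# with load_features_concat_dataset imported from wherever this module lives inside eegdash
concat_ds = load_features_concat_dataset("derivatives/features", ids_to_load=[0, 2], n_jobs=2)
print(len(concat_ds.datasets))  # number of FeaturesDataset objects that were loaded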
@@ -0,0 +1,114 @@
+ from typing import Dict, List
+ from collections.abc import Callable
+ import copy
+ import numpy as np
+ import pandas as pd
+ from joblib import Parallel, delayed
+ from tqdm import tqdm
+ from torch.utils.data import DataLoader
+ from braindecode.datasets.base import (
+     EEGWindowsDataset,
+     WindowsDataset,
+     BaseConcatDataset,
+ )
+
+ from .datasets import FeaturesDataset, FeaturesConcatDataset
+ from .extractors import FeatureExtractor
+
+
+ def _extract_features_from_windowsdataset(
+     win_ds: EEGWindowsDataset | WindowsDataset,
+     feature_extractor: FeatureExtractor,
+     batch_size: int = 512,
+ ):
+     metadata = win_ds.metadata
+     if not win_ds.targets_from == "metadata":
+         metadata = copy.deepcopy(metadata)
+         metadata["orig_index"] = metadata.index
+         metadata.set_index(
+             ["i_window_in_trial", "i_start_in_trial", "i_stop_in_trial"],
+             drop=False,
+             inplace=True,
+         )
+     win_dl = DataLoader(win_ds, batch_size=batch_size, shuffle=False, drop_last=False)
+     features_dict = dict()
+     ch_names = win_ds.raw.ch_names
+     for X, y, crop_inds in win_dl:
+         X = X.numpy()
+         if hasattr(y, "tolist"):
+             y = y.tolist()
+         win_dict = dict()
+         win_dict.update(
+             feature_extractor(X, _batch_size=X.shape[0], _ch_names=ch_names)
+         )
+         if not win_ds.targets_from == "metadata":
+             metadata.loc[crop_inds, "target"] = y
+         for k, v in win_dict.items():
+             if k not in features_dict:
+                 features_dict[k] = []
+             features_dict[k].extend(v)
+     features_df = pd.DataFrame(features_dict)
+     if not win_ds.targets_from == "metadata":
+         metadata.set_index("orig_index", drop=False, inplace=True)
+         metadata.reset_index(drop=True, inplace=True)
+         metadata.drop("orig_index", axis=1, inplace=True)
+
+     # FUTURE: truly support WindowsDataset objects
+     return FeaturesDataset(
+         features_df,
+         metadata=metadata,
+         description=win_ds.description,
+         raw_info=win_ds.raw.info,
+         raw_preproc_kwargs=win_ds.raw_preproc_kwargs,
+         window_kwargs=win_ds.window_kwargs,
+         features_kwargs=feature_extractor.features_kwargs,
+     )
+
+
+ def extract_features(
+     concat_dataset: BaseConcatDataset,
+     features: FeatureExtractor | Dict[str, Callable] | List[Callable],
+     *,
+     batch_size: int = 512,
+     n_jobs: int = 1,
+ ):
+     if isinstance(features, list):
+         features = dict(enumerate(features))
+     if not isinstance(features, FeatureExtractor):
+         features = FeatureExtractor(features)
+     feature_ds_list = list(
+         tqdm(
+             Parallel(n_jobs=n_jobs, return_as="generator")(
+                 delayed(_extract_features_from_windowsdataset)(
+                     win_ds, features, batch_size
+                 )
+                 for win_ds in concat_dataset.datasets
+             ),
+             total=len(concat_dataset.datasets),
+             desc="Extracting features",
+         )
+     )
+     return FeaturesConcatDataset(feature_ds_list)
+
+
+ def fit_feature_extractors(
+     concat_dataset: BaseConcatDataset,
+     features: FeatureExtractor | Dict[str, Callable] | List[Callable],
+     batch_size: int = 8192,
+ ):
+     if isinstance(features, list):
+         features = dict(enumerate(features))
+     if not isinstance(features, FeatureExtractor):
+         features = FeatureExtractor(features)
+     if not features._is_fitable:
+         return features
+     features.clear()
+     concat_dl = DataLoader(
+         concat_dataset, batch_size=batch_size, shuffle=False, drop_last=False
+     )
+     for X, y, _ in tqdm(
+         concat_dl, total=len(concat_dl), desc="Fitting feature extractors"
+     ):
+         features.partial_fit(X.numpy(), y=np.array(y))
+     features.fit()
+     return features
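
A possible end-to-end call of extract_features, assuming a windowed braindecode BaseConcatDataset named windows_ds already exists and that the signal features from the first new file are importable from the package's feature bank (the exact module path is not shown in this diff, so the import below is an assumption):

# module path assumed; adjust to wherever the feature bank lives inside eegdash
from eegdash.features import extract_features, signal_mean, signal_peak_to_peak

# windows_ds: a windowed braindecode BaseConcatDataset (assumed to exist)
features = {"mean": signal_mean, "ptp": signal_peak_to_peak}  # a dict of callables is wrapped in a FeatureExtractor
feat_concat_ds = extract_features(windows_ds, features, batch_size=256, n_jobs=2)
print(len(feat_concat_ds.datasets))  # one FeaturesDataset per recording in windows_ds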
eegdash/main.py CHANGED
@@ -1,15 +1,16 @@
- from typing import List
  import pymongo
  from dotenv import load_dotenv
  import os
  from pathlib import Path
  import s3fs
  from joblib import Parallel, delayed
+ import json
  import tempfile
  import mne
  import numpy as np
  import xarray as xr
- from .data_utils import BIDSDataset, EEGDashBaseRaw, EEGDashBaseDataset
+ from .data_utils import EEGBIDSDataset, EEGDashBaseRaw, EEGDashBaseDataset
+ from .data_config import config as data_config
  from braindecode.datasets import BaseDataset, BaseConcatDataset
  from collections import defaultdict
  from pymongo import MongoClient, InsertOne, UpdateOne, DeleteOne
@@ -18,6 +19,12 @@ class EEGDash:
      AWS_BUCKET = 's3://openneuro.org'
      def __init__(self,
                   is_public=True):
+         # Load config file
+         # config_path = Path(__file__).parent / 'config.json'
+         # with open(config_path, 'r') as f:
+         #     self.config = json.load(f)
+
+         self.config = data_config
          if is_public:
              DB_CONNECTION_STRING="mongodb+srv://eegdash-user:mdzoMjQcHWTVnKDq@cluster0.vz35p.mongodb.net/?retryWrites=true&w=majority&appName=Cluster0"
          else:
@@ -37,10 +44,9 @@ class EEGDash:
          # convert to list using get_item on each element
          return [result for result in results]

-     def exist(self, data_name=''):
-         query = {
-             "data_name": data_name
-         }
+     def exist(self, query:dict):
+         accepted_query_fields = ['data_name', 'dataset']
+         assert all(field in accepted_query_fields for field in query.keys())
          sessions = self.find(query)
          return len(sessions) > 0

@@ -104,66 +110,111 @@ class EEGDash:
          )
          return eeg_xarray

-     def load_eeg_attrs_from_bids_file(self, bids_dataset: BIDSDataset, bids_file):
+     def get_raw_extensions(self, bids_file, bids_dataset: EEGBIDSDataset):
+         bids_file = Path(bids_file)
+         extensions = {
+             '.set': ['.set', '.fdt'], # eeglab
+             '.edf': ['.edf'], # european
+             '.vhdr': ['.eeg', '.vhdr', '.vmrk', '.dat', '.raw'], # brainvision
+             '.bdf': ['.bdf'], # biosemi
+         }
+         return [str(bids_dataset.get_relative_bidspath(bids_file.with_suffix(suffix))) for suffix in extensions[bids_file.suffix] if bids_file.with_suffix(suffix).exists()]
+
+     def load_eeg_attrs_from_bids_file(self, bids_dataset: EEGBIDSDataset, bids_file):
          '''
          bids_file must be a file of the bids_dataset
          '''
          if bids_file not in bids_dataset.files:
              raise ValueError(f'{bids_file} not in {bids_dataset.dataset}')
+
+         # Initialize attrs with None values for all expected fields
+         attrs = {field: None for field in self.config['attributes'].keys()}
+
          f = os.path.basename(bids_file)
          dsnumber = bids_dataset.dataset
          # extract openneuro path by finding the first occurrence of the dataset name in the filename and remove the path before that
          openneuro_path = dsnumber + bids_file.split(dsnumber)[1]

-         bids_dependencies_files = ['dataset_description.json', 'participants.tsv', 'events.tsv', 'events.json', 'eeg.json', 'electrodes.tsv', 'channels.tsv', 'coordsystem.json']
+         # Update with actual values where available
+         try:
+             participants_tsv = bids_dataset.subject_participant_tsv(bids_file)
+         except Exception as e:
+             print(f"Error getting participants_tsv: {str(e)}")
+             participants_tsv = None
+
+         try:
+             eeg_json = bids_dataset.eeg_json(bids_file)
+         except Exception as e:
+             print(f"Error getting eeg_json: {str(e)}")
+             eeg_json = None
+
+         bids_dependencies_files = self.config['bids_dependencies_files']
          bidsdependencies = []
          for extension in bids_dependencies_files:
-             dep_path = bids_dataset.get_bids_metadata_files(bids_file, extension)
-             dep_path = [str(bids_dataset.get_relative_bidspath(dep)) for dep in dep_path]
-
-             bidsdependencies.extend(dep_path)
-
-         participants_tsv = bids_dataset.subject_participant_tsv(bids_file)
-         eeg_json = bids_dataset.eeg_json(bids_file)
-         attrs = {
-             'data_name': f'{bids_dataset.dataset}_{f}',
-             'dataset': bids_dataset.dataset,
-             'bidspath': openneuro_path,
-             'subject': bids_dataset.subject(bids_file),
-             'task': bids_dataset.task(bids_file),
-             'session': bids_dataset.session(bids_file),
-             'run': bids_dataset.run(bids_file),
-             'modality': 'EEG',
-             'nchans': bids_dataset.num_channels(bids_file),
-             'ntimes': bids_dataset.num_times(bids_file),
-             'participant_tsv': participants_tsv,
-             'eeg_json': eeg_json,
-             'bidsdependencies': bidsdependencies,
+             try:
+                 dep_path = bids_dataset.get_bids_metadata_files(bids_file, extension)
+                 dep_path = [str(bids_dataset.get_relative_bidspath(dep)) for dep in dep_path]
+                 bidsdependencies.extend(dep_path)
+             except Exception as e:
+                 pass
+
+         bidsdependencies.extend(self.get_raw_extensions(bids_file, bids_dataset))
+
+         # Define field extraction functions with error handling
+         field_extractors = {
+             'data_name': lambda: f'{bids_dataset.dataset}_{f}',
+             'dataset': lambda: bids_dataset.dataset,
+             'bidspath': lambda: openneuro_path,
+             'subject': lambda: bids_dataset.get_bids_file_attribute('subject', bids_file),
+             'task': lambda: bids_dataset.get_bids_file_attribute('task', bids_file),
+             'session': lambda: bids_dataset.get_bids_file_attribute('session', bids_file),
+             'run': lambda: bids_dataset.get_bids_file_attribute('run', bids_file),
+             'modality': lambda: bids_dataset.get_bids_file_attribute('modality', bids_file),
+             'sampling_frequency': lambda: bids_dataset.get_bids_file_attribute('sfreq', bids_file),
+             'nchans': lambda: bids_dataset.get_bids_file_attribute('nchans', bids_file),
+             'ntimes': lambda: bids_dataset.get_bids_file_attribute('ntimes', bids_file),
+             'participant_tsv': lambda: participants_tsv,
+             'eeg_json': lambda: eeg_json,
+             'bidsdependencies': lambda: bidsdependencies,
          }
+
+         # Dynamically populate attrs with error handling
+         for field, extractor in field_extractors.items():
+             try:
+                 attrs[field] = extractor()
+             except Exception as e:
+                 print(f"Error extracting {field}: {str(e)}")
+                 attrs[field] = None

          return attrs

-     def add_bids_dataset(self, dataset, data_dir, raw_format='eeglab', overwrite=True):
+     def add_bids_dataset(self, dataset, data_dir, overwrite=True):
          '''
          Create new records for the dataset in the MongoDB database if not found
          '''
          if self.is_public:
              raise ValueError('This operation is not allowed for public users')

-         bids_dataset = BIDSDataset(
-             data_dir=data_dir,
-             dataset=dataset,
-             raw_format=raw_format,
-         )
+         if not overwrite and self.exist({'dataset': dataset}):
+             print(f'Dataset {dataset} already exists in the database')
+             return
+         try:
+             bids_dataset = EEGBIDSDataset(
+                 data_dir=data_dir,
+                 dataset=dataset,
+             )
+         except Exception as e:
+             print(f'Error creating bids dataset {dataset}: {str(e)}')
+             raise e
          requests = []
          for bids_file in bids_dataset.get_files():
              try:
                  data_id = f"{dataset}_{os.path.basename(bids_file)}"

-                 if self.exist(data_name=data_id):
+                 if self.exist({'data_name':data_id}):
                      if overwrite:
                          eeg_attrs = self.load_eeg_attrs_from_bids_file(bids_dataset, bids_file)
-                         requests.append(UpdateOne(self.update_request(eeg_attrs)))
+                         requests.append(self.update_request(eeg_attrs))
                      else:
                          eeg_attrs = self.load_eeg_attrs_from_bids_file(bids_dataset, bids_file)
                          requests.append(self.add_request(eeg_attrs))
@@ -224,17 +275,22 @@ class EEGDash:
      def remove_field_from_db(self, field):
          self.__collection.update_many({}, {'$unset': {field: 1}})

+     @property
+     def collection(self):
+         return self.__collection

  class EEGDashDataset(BaseConcatDataset):
-     CACHE_DIR = '.eegdash_cache'
+     # CACHE_DIR = '.eegdash_cache'
      def __init__(
          self,
          query:dict=None,
          data_dir:str | list =None,
          dataset:str | list =None,
          description_fields: list[str]=['subject', 'session', 'run', 'task', 'age', 'gender', 'sex'],
+         cache_dir:str='.eegdash_cache',
          **kwargs
      ):
+         self.cache_dir = cache_dir
          if query:
              datasets = self.find_datasets(query, description_fields, **kwargs)
          elif data_dir:
@@ -247,6 +303,7 @@ class EEGDashDataset(BaseConcatDataset):
              datasets.extend(self.load_bids_dataset(dataset[i], data_dir[i], description_fields))
          # convert to list using get_item on each element
          super().__init__(datasets)
+

      def find_key_in_nested_dict(self, data, target_key):
          if isinstance(data, dict):
@@ -267,7 +324,7 @@ class EEGDashDataset(BaseConcatDataset):
                  value = self.find_key_in_nested_dict(record, field)
                  if value:
                      description[field] = value
-             datasets.append(EEGDashBaseDataset(record, self.CACHE_DIR, description=description, **kwargs))
+             datasets.append(EEGDashBaseDataset(record, self.cache_dir, description=description, **kwargs))
          return datasets

      def load_bids_dataset(self, dataset, data_dir, description_fields: list[str],raw_format='eeglab', **kwargs):
@@ -280,9 +337,9 @@ class EEGDashDataset(BaseConcatDataset):
                  value = self.find_key_in_nested_dict(record, field)
                  if value:
                      description[field] = value
-             return EEGDashBaseDataset(record, self.CACHE_DIR, description=description, **kwargs)
+             return EEGDashBaseDataset(record, self.cache_dir, description=description, **kwargs)

-         bids_dataset = BIDSDataset(
+         bids_dataset = EEGBIDSDataset(
              data_dir=data_dir,
              dataset=dataset,
              raw_format=raw_format,
@@ -291,15 +348,6 @@ class EEGDashDataset(BaseConcatDataset):
          datasets = Parallel(n_jobs=-1, prefer="threads", verbose=1)(
              delayed(get_base_dataset_from_bids_file)(bids_dataset, bids_file) for bids_file in bids_dataset.get_files()
          )
-         # datasets = []
-         # for bids_file in bids_dataset.get_files():
-         #     record = eegdashObj.load_eeg_attrs_from_bids_file(bids_dataset, bids_file)
-         #     description = {}
-         #     for field in description_fields:
-         #         value = self.find_key_in_nested_dict(record, field)
-         #         if value:
-         #             description[field] = value
-         #     datasets.append(EEGDashBaseDataset(record, self.CACHE_DIR, description=description, **kwargs))
          return datasets

  def main():
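
For reference, a hedged sketch of how the APIs changed in eegdash/main.py would now be called; the dataset id is illustrative and the import path is assumed to match earlier releases:

from eegdash import EEGDash, EEGDashDataset  # import path assumed

eegdash = EEGDash()                               # public, read-only connection
print(eegdash.exist({"dataset": "ds002718"}))     # exist() now takes a query dict ('data_name' or 'dataset')

# cache_dir is now an __init__ argument instead of the class-level CACHE_DIR constant
ds = EEGDashDataset(query={"dataset": "ds002718"}, cache_dir=".eegdash_cache")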