eegdash 0.3.3.dev61__py3-none-any.whl → 0.5.0.dev180784713__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the changes between the versions exactly as they appear in their public registries.
Files changed (45)
  1. eegdash/__init__.py +19 -6
  2. eegdash/api.py +336 -539
  3. eegdash/bids_eeg_metadata.py +495 -0
  4. eegdash/const.py +349 -0
  5. eegdash/dataset/__init__.py +28 -0
  6. eegdash/dataset/base.py +311 -0
  7. eegdash/dataset/bids_dataset.py +641 -0
  8. eegdash/dataset/dataset.py +692 -0
  9. eegdash/dataset/dataset_summary.csv +255 -0
  10. eegdash/dataset/registry.py +287 -0
  11. eegdash/downloader.py +197 -0
  12. eegdash/features/__init__.py +15 -13
  13. eegdash/features/datasets.py +329 -138
  14. eegdash/features/decorators.py +105 -13
  15. eegdash/features/extractors.py +233 -63
  16. eegdash/features/feature_bank/__init__.py +12 -12
  17. eegdash/features/feature_bank/complexity.py +22 -20
  18. eegdash/features/feature_bank/connectivity.py +27 -28
  19. eegdash/features/feature_bank/csp.py +3 -1
  20. eegdash/features/feature_bank/dimensionality.py +6 -6
  21. eegdash/features/feature_bank/signal.py +29 -30
  22. eegdash/features/feature_bank/spectral.py +40 -44
  23. eegdash/features/feature_bank/utils.py +8 -0
  24. eegdash/features/inspect.py +126 -15
  25. eegdash/features/serialization.py +58 -17
  26. eegdash/features/utils.py +90 -16
  27. eegdash/hbn/__init__.py +28 -0
  28. eegdash/hbn/preprocessing.py +105 -0
  29. eegdash/hbn/windows.py +428 -0
  30. eegdash/logging.py +54 -0
  31. eegdash/mongodb.py +55 -24
  32. eegdash/paths.py +52 -0
  33. eegdash/utils.py +29 -1
  34. eegdash-0.5.0.dev180784713.dist-info/METADATA +121 -0
  35. eegdash-0.5.0.dev180784713.dist-info/RECORD +38 -0
  36. eegdash-0.5.0.dev180784713.dist-info/licenses/LICENSE +29 -0
  37. eegdash/data_config.py +0 -34
  38. eegdash/data_utils.py +0 -687
  39. eegdash/dataset.py +0 -69
  40. eegdash/preprocessing.py +0 -63
  41. eegdash-0.3.3.dev61.dist-info/METADATA +0 -192
  42. eegdash-0.3.3.dev61.dist-info/RECORD +0 -28
  43. eegdash-0.3.3.dev61.dist-info/licenses/LICENSE +0 -23
  44. {eegdash-0.3.3.dev61.dist-info → eegdash-0.5.0.dev180784713.dist-info}/WHEEL +0 -0
  45. {eegdash-0.3.3.dev61.dist-info → eegdash-0.5.0.dev180784713.dist-info}/top_level.txt +0 -0
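Most of the churn above is a package restructure: the flat data_utils.py, data_config.py, dataset.py, and preprocessing.py modules are removed in favor of the new dataset/, hbn/, and sibling modules. For downstream code, a minimal migration sketch, assuming the moved names keep their old behavior (the new paths below are the ones visible in the api.py import block of this diff):

    # Old (0.3.3.dev61) import paths, removed in 0.5.0:
    # from eegdash.data_utils import EEGBIDSDataset
    # from eegdash.data_config import config

    # New (0.5.0.dev180784713) paths, as imported by eegdash/api.py:
    from eegdash.dataset.bids_dataset import EEGBIDSDataset
    from eegdash.const import config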
eegdash/api.py CHANGED
@@ -1,54 +1,65 @@
- import logging
+ # Authors: The EEGDash contributors.
+ # License: BSD-3-Clause
+ # Copyright the EEGDash contributors.
+
+ """High-level interface to the EEGDash metadata database.
+
+ This module provides the main EEGDash class which serves as the primary entry point for
+ interacting with the EEGDash ecosystem. It offers methods to query, insert, and update
+ metadata records stored in the EEGDash MongoDB database, and includes utilities to load
+ EEG data from S3 for matched records.
+ """
+
+ import json
  import os
- import tempfile
  from pathlib import Path
  from typing import Any, Mapping

  import mne
  import numpy as np
- import xarray as xr
- from dotenv import load_dotenv
- from joblib import Parallel, delayed
- from pymongo import InsertOne, UpdateOne
- from s3fs import S3FileSystem
-
- from braindecode.datasets import BaseConcatDataset
-
- from .data_config import config as data_config
- from .data_utils import EEGBIDSDataset, EEGDashBaseDataset
+ import pandas as pd
+ from mne.utils import _soft_import
+
+ from .bids_eeg_metadata import (
+     build_query_from_kwargs,
+     load_eeg_attrs_from_bids_file,
+ )
+ from .const import (
+     ALLOWED_QUERY_FIELDS,
+ )
+ from .const import config as data_config
+ from .dataset.bids_dataset import EEGBIDSDataset
+ from .logging import logger
  from .mongodb import MongoConnectionManager
-
- logger = logging.getLogger("eegdash")
+ from .utils import _init_mongo_client


  class EEGDash:
-     """A high-level interface to the EEGDash database.
-
-     This class is primarily used to interact with the metadata records stored in the
-     EEGDash database (or a private instance of it), allowing users to find, add, and
-     update EEG data records.
+     """High-level interface to the EEGDash metadata database.

-     While this class provides basic support for loading EEG data, please see
-     the EEGDashDataset class for a more complete way to retrieve and work with full
-     datasets.
+     Provides methods to query, insert, and update metadata records stored in the
+     EEGDash MongoDB database (public or private). Also includes utilities to load
+     EEG data from S3 for matched records.

+     For working with collections of
+     recordings as PyTorch datasets, prefer :class:`EEGDashDataset`.
      """

      def __init__(self, *, is_public: bool = True, is_staging: bool = False) -> None:
-         """Create new instance of the EEGDash Database client.
+         """Create a new EEGDash client.

          Parameters
          ----------
-         is_public: bool
-             Whether to connect to the public MongoDB database; if False, connect to a
-             private database instance as per the DB_CONNECTION_STRING env variable
-             (or .env file entry).
-         is_staging: bool
-             If True, use staging MongoDB database ("eegdashstaging"); otherwise use the
-             production database ("eegdash").
-
-         Example
-         -------
+         is_public : bool, default True
+             Connect to the public MongoDB database. If ``False``, connect to a
+             private database instance using the ``DB_CONNECTION_STRING`` environment
+             variable (or value from a ``.env`` file).
+         is_staging : bool, default False
+             If ``True``, use the staging database (``eegdashstaging``); otherwise
+             use the production database (``eegdash``).
+
+         Examples
+         --------
          >>> eegdash = EEGDash()

          """
@@ -58,47 +69,91 @@ class EEGDash:

          if self.is_public:
              DB_CONNECTION_STRING = mne.utils.get_config("EEGDASH_DB_URI")
+             if not DB_CONNECTION_STRING:
+                 try:
+                     _init_mongo_client()
+                     DB_CONNECTION_STRING = mne.utils.get_config("EEGDASH_DB_URI")
+                 except Exception:
+                     DB_CONNECTION_STRING = None
          else:
-             load_dotenv()
+             dotenv = _soft_import("dotenv", "eegdash[full] is necessary.")
+             dotenv.load_dotenv()
              DB_CONNECTION_STRING = os.getenv("DB_CONNECTION_STRING")

          # Use singleton to get MongoDB client, database, and collection
+         if not DB_CONNECTION_STRING:
+             raise RuntimeError(
+                 "No MongoDB connection string configured. Set MNE config 'EEGDASH_DB_URI' "
+                 "or environment variable 'DB_CONNECTION_STRING'."
+             )
          self.__client, self.__db, self.__collection = MongoConnectionManager.get_client(
              DB_CONNECTION_STRING, is_staging
          )

-         self.filesystem = S3FileSystem(
-             anon=True, client_kwargs={"region_name": "us-east-2"}
-         )
+     def find(
+         self, query: dict[str, Any] = None, /, **kwargs
+     ) -> list[Mapping[str, Any]]:
+         """Find records in the MongoDB collection.

-     def find(self, query: dict[str, Any], *args, **kwargs) -> list[Mapping[str, Any]]:
-         """Find records in the MongoDB collection that satisfy the given query.
+         Examples
+         --------
+         >>> eegdash.find({"dataset": "ds002718", "subject": {"$in": ["012", "013"]}})  # pre-built query
+         >>> eegdash.find(dataset="ds002718", subject="012")  # keyword filters
+         >>> eegdash.find(dataset="ds002718", subject=["012", "013"])  # sequence -> $in
+         >>> eegdash.find({})  # fetch all (use with care)
+         >>> eegdash.find({"dataset": "ds002718"}, subject=["012", "013"])  # combine query + kwargs (AND)

          Parameters
          ----------
-         query: dict
-             A dictionary that specifies the query to be executed; this is a reference
-             document that is used to match records in the MongoDB collection.
-         args:
-             Additional positional arguments for the MongoDB find() method; see
-             https://pymongo.readthedocs.io/en/stable/api/pymongo/collection.html#pymongo.collection.Collection.find
-         kwargs:
-             Additional keyword arguments for the MongoDB find() method.
+         query : dict, optional
+             Complete MongoDB query dictionary. This is a positional-only
+             argument.
+         **kwargs
+             User-friendly field filters that are converted to a MongoDB query.
+             Values can be scalars (e.g., ``"sub-01"``) or sequences (translated
+             to ``$in`` queries).

          Returns
          -------
-         list:
-             A list of DB records (string-keyed dictionaries) that match the query.
-
-         Example
-         -------
-         >>> eegdash = EEGDash()
-         >>> eegdash.find({"dataset": "ds002718", "subject": "012"})
+         list of dict
+             DB records that match the query.

          """
-         results = self.__collection.find(query, *args, **kwargs)
+         final_query: dict[str, Any] | None = None
+
+         # Accept explicit empty dict {} to mean "match all"
+         raw_query = query if isinstance(query, dict) else None
+         kwargs_query = build_query_from_kwargs(**kwargs) if kwargs else None
+
+         # Determine presence, treating {} as a valid raw query
+         has_raw = isinstance(raw_query, dict)
+         has_kwargs = kwargs_query is not None
+
+         if has_raw and has_kwargs:
+             # Detect conflicting constraints on the same field (e.g., task specified
+             # differently in both places) and raise a clear error instead of silently
+             # producing an empty result.
+             self._raise_if_conflicting_constraints(raw_query, kwargs_query)
+             # Merge with logical AND so both constraints apply
+             if raw_query:  # non-empty dict adds constraints
+                 final_query = {"$and": [raw_query, kwargs_query]}
+             else:  # {} adds nothing; use kwargs_query only
+                 final_query = kwargs_query
+         elif has_raw:
+             # May be {} meaning match-all, or a non-empty dict
+             final_query = raw_query
+         elif has_kwargs:
+             final_query = kwargs_query
+         else:
+             # Avoid accidental full scans
+             raise ValueError(
+                 "find() requires a query dictionary or at least one keyword argument. "
+                 "To find all documents, use find({})."
+             )

-         return [result for result in results]
+         results = self.__collection.find(final_query)
+
+         return list(results)

      def exist(self, query: dict[str, Any]) -> bool:
          """Return True if at least one record matches the query, else False.
@@ -145,17 +200,22 @@ class EEGDash:
          return doc is not None

      def _validate_input(self, record: dict[str, Any]) -> dict[str, Any]:
-         """Internal method to validate the input record against the expected schema.
+         """Validate the input record against the expected schema.

          Parameters
          ----------
-         record: dict
+         record : dict
              A dictionary representing the EEG data record to be validated.

          Returns
          -------
-         dict:
-             Returns the record itself on success, or raises a ValueError if the record is invalid.
+         dict
+             The record itself on success.
+
+         Raises
+         ------
+         ValueError
+             If the record is missing required keys or has values of the wrong type.

          """
          input_types = {
@@ -184,548 +244,285 @@ class EEGDash:

          return record

-     def load_eeg_data_from_s3(self, s3path: str) -> xr.DataArray:
-         """Load an EEGLAB .set file from an AWS S3 URI and return it as an xarray DataArray.
+     def _build_query_from_kwargs(self, **kwargs) -> dict[str, Any]:
+         """Build a validated MongoDB query from keyword arguments.
+
+         This delegates to the module-level builder used across the package.

          Parameters
          ----------
-         s3path : str
-             An S3 URI (should start with "s3://") for the file in question.
+         **kwargs
+             Keyword arguments to convert into a MongoDB query.

          Returns
          -------
-         xr.DataArray
-             A DataArray containing the EEG data, with dimensions "channel" and "time".
-
-         Example
-         -------
-         >>> eegdash = EEGDash()
-         >>> mypath = "s3://openneuro.org/path/to/your/eeg_data.set"
-         >>> mydata = eegdash.load_eeg_data_from_s3(mypath)
+         dict
+             A MongoDB query dictionary.

          """
-         with tempfile.NamedTemporaryFile(delete=False, suffix=".set") as tmp:
-             with self.filesystem.open(s3path) as s3_file:
-                 tmp.write(s3_file.read())
-             tmp_path = tmp.name
-         eeg_data = self.load_eeg_data_from_bids_file(tmp_path)
-         os.unlink(tmp_path)
-         return eeg_data
-
-     def load_eeg_data_from_bids_file(self, bids_file: str) -> xr.DataArray:
-         """Load EEG data from a local file and return it as a xarray DataArray.
-
-         Parameters
-         ----------
-         bids_file : str
-             Path to the file on the local filesystem.
-
-         Notes
-         -----
-         Currently, only non-epoched .set files are supported.
-
-         """
-         raw_object = mne.io.read_raw(bids_file)
-         eeg_data = raw_object.get_data()
-
-         fs = raw_object.info["sfreq"]
-         max_time = eeg_data.shape[1] / fs
-         time_steps = np.linspace(0, max_time, eeg_data.shape[1]).squeeze()  # in seconds
-
-         channel_names = raw_object.ch_names
+         return build_query_from_kwargs(**kwargs)

-         eeg_xarray = xr.DataArray(
-             data=eeg_data,
-             dims=["channel", "time"],
-             coords={"time": time_steps, "channel": channel_names},
-         )
-         return eeg_xarray
-
-     def get_raw_extensions(
-         self, bids_file: str, bids_dataset: EEGBIDSDataset
-     ) -> list[str]:
-         """Helper to find paths to additional "sidecar" files that may be associated
-         with a given main data file in a BIDS dataset; paths are returned as relative to
-         the parent dataset path.
+     def _extract_simple_constraint(
+         self, query: dict[str, Any], key: str
+     ) -> tuple[str, Any] | None:
+         """Extract a simple constraint for a given key from a query dict.

-         For example, if the input file is a .set file, this will return the relative path
-         to a corresponding .fdt file (if any).
-         """
-         bids_file = Path(bids_file)
-         extensions = {
-             ".set": [".set", ".fdt"],  # eeglab
-             ".edf": [".edf"],  # european
-             ".vhdr": [".eeg", ".vhdr", ".vmrk", ".dat", ".raw"],  # brainvision
-             ".bdf": [".bdf"],  # biosemi
-         }
-         return [
-             str(bids_dataset.get_relative_bidspath(bids_file.with_suffix(suffix)))
-             for suffix in extensions[bids_file.suffix]
-             if bids_file.with_suffix(suffix).exists()
-         ]
-
-     def load_eeg_attrs_from_bids_file(
-         self, bids_dataset: EEGBIDSDataset, bids_file: str
-     ) -> dict[str, Any]:
-         """Build the metadata record for a given BIDS file (single recording) in a BIDS dataset.
-
-         Attributes are at least the ones defined in data_config attributes (set to None if missing),
-         but are typically a superset, and include, among others, the paths to relevant
-         meta-data files needed to load and interpret the file in question.
+         Supports top-level equality (e.g., ``{'subject': '01'}``) and ``$in``
+         (e.g., ``{'subject': {'$in': ['01', '02']}}``) constraints.

          Parameters
          ----------
-         bids_dataset : EEGBIDSDataset
-             The BIDS dataset object containing the file.
-         bids_file : str
-             The path to the BIDS file within the dataset.
+         query : dict
+             The MongoDB query dictionary.
+         key : str
+             The key for which to extract the constraint.

          Returns
          -------
-         dict:
-             A dictionary representing the metadata record for the given file. This is the
-             same format as the records stored in the database.
+         tuple or None
+             A tuple of (kind, value) where kind is "eq" or "in", or None if the
+             constraint is not present or unsupported.

          """
-         if bids_file not in bids_dataset.files:
-             raise ValueError(f"{bids_file} not in {bids_dataset.dataset}")
+         if not isinstance(query, dict) or key not in query:
+             return None
+         val = query[key]
+         if isinstance(val, dict):
+             if "$in" in val and isinstance(val["$in"], (list, tuple)):
+                 return ("in", list(val["$in"]))
+             return None  # unsupported operator shape for conflict checking
+         else:
+             return "eq", val

-         # Initialize attrs with None values for all expected fields
-         attrs = {field: None for field in self.config["attributes"].keys()}
+     def _raise_if_conflicting_constraints(
+         self, raw_query: dict[str, Any], kwargs_query: dict[str, Any]
+     ) -> None:
+         """Raise ValueError if query sources have incompatible constraints.

-         file = Path(bids_file).name
-         dsnumber = bids_dataset.dataset
-         # extract openneuro path by finding the first occurrence of the dataset name in the filename and remove the path before that
-         openneuro_path = dsnumber + bids_file.split(dsnumber)[1]
+         Checks for mutually exclusive constraints on the same field to avoid
+         silent empty results.

-         # Update with actual values where available
-         try:
-             participants_tsv = bids_dataset.subject_participant_tsv(bids_file)
-         except Exception as e:
-             logger.error("Error getting participants_tsv: %s", str(e))
-             participants_tsv = None
+         Parameters
+         ----------
+         raw_query : dict
+             The raw MongoDB query dictionary.
+         kwargs_query : dict
+             The query dictionary built from keyword arguments.

-         try:
-             eeg_json = bids_dataset.eeg_json(bids_file)
-         except Exception as e:
-             logger.error("Error getting eeg_json: %s", str(e))
-             eeg_json = None
-
-         bids_dependencies_files = self.config["bids_dependencies_files"]
-         bidsdependencies = []
-         for extension in bids_dependencies_files:
-             try:
-                 dep_path = bids_dataset.get_bids_metadata_files(bids_file, extension)
-                 dep_path = [
-                     str(bids_dataset.get_relative_bidspath(dep)) for dep in dep_path
-                 ]
-                 bidsdependencies.extend(dep_path)
-             except Exception:
-                 pass
-
-         bidsdependencies.extend(self.get_raw_extensions(bids_file, bids_dataset))
-
-         # Define field extraction functions with error handling
-         field_extractors = {
-             "data_name": lambda: f"{bids_dataset.dataset}_{file}",
-             "dataset": lambda: bids_dataset.dataset,
-             "bidspath": lambda: openneuro_path,
-             "subject": lambda: bids_dataset.get_bids_file_attribute(
-                 "subject", bids_file
-             ),
-             "task": lambda: bids_dataset.get_bids_file_attribute("task", bids_file),
-             "session": lambda: bids_dataset.get_bids_file_attribute(
-                 "session", bids_file
-             ),
-             "run": lambda: bids_dataset.get_bids_file_attribute("run", bids_file),
-             "modality": lambda: bids_dataset.get_bids_file_attribute(
-                 "modality", bids_file
-             ),
-             "sampling_frequency": lambda: bids_dataset.get_bids_file_attribute(
-                 "sfreq", bids_file
-             ),
-             "nchans": lambda: bids_dataset.get_bids_file_attribute("nchans", bids_file),
-             "ntimes": lambda: bids_dataset.get_bids_file_attribute("ntimes", bids_file),
-             "participant_tsv": lambda: participants_tsv,
-             "eeg_json": lambda: eeg_json,
-             "bidsdependencies": lambda: bidsdependencies,
-         }
+         Raises
+         ------
+         ValueError
+             If conflicting constraints are found.

-         # Dynamically populate attrs with error handling
-         for field, extractor in field_extractors.items():
-             try:
-                 attrs[field] = extractor()
-             except Exception as e:
-                 logger.error("Error extracting %s : %s", field, str(e))
-                 attrs[field] = None
+         """
+         if not raw_query or not kwargs_query:
+             return

-         return attrs
+         # Only consider fields we generally allow; skip meta operators like $and
+         raw_keys = set(raw_query.keys()) & ALLOWED_QUERY_FIELDS
+         kw_keys = set(kwargs_query.keys()) & ALLOWED_QUERY_FIELDS
+         dup_keys = raw_keys & kw_keys
+         for key in dup_keys:
+             rc = self._extract_simple_constraint(raw_query, key)
+             kc = self._extract_simple_constraint(kwargs_query, key)
+             if rc is None or kc is None:
+                 # If either side is non-simple, skip conflict detection for this key
+                 continue
+
+             r_kind, r_val = rc
+             k_kind, k_val = kc
+
+             # Normalize to sets when appropriate for simpler checks
+             if r_kind == "eq" and k_kind == "eq":
+                 if r_val != k_val:
+                     raise ValueError(
+                         f"Conflicting constraints for '{key}': query={r_val!r} vs kwargs={k_val!r}"
+                     )
+             elif r_kind == "in" and k_kind == "eq":
+                 if k_val not in r_val:
+                     raise ValueError(
+                         f"Conflicting constraints for '{key}': query in {r_val!r} vs kwargs={k_val!r}"
+                     )
+             elif r_kind == "eq" and k_kind == "in":
+                 if r_val not in k_val:
+                     raise ValueError(
+                         f"Conflicting constraints for '{key}': query={r_val!r} vs kwargs in {k_val!r}"
+                     )
+             elif r_kind == "in" and k_kind == "in":
+                 if len(set(r_val).intersection(k_val)) == 0:
+                     raise ValueError(
+                         f"Conflicting constraints for '{key}': disjoint sets {r_val!r} and {k_val!r}"
+                     )

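The two helpers above back the conflict check that find() runs before merging a raw query with keyword filters. A sketch of what does and does not trigger it, based on the eq/$in cases handled in the code:

    # Compatible: "012" lies inside the raw $in list, so the queries merge.
    eegdash.find({"subject": {"$in": ["012", "013"]}}, subject="012")

    # Conflicting: equality constraints on the same field with different values
    # raise ValueError instead of silently returning an empty result.
    try:
        eegdash.find({"task": "RestingState"}, task="Video")
    except ValueError as err:
        print(err)  # Conflicting constraints for 'task': ...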
      def add_bids_dataset(
-         self, dataset: str, data_dir: str, overwrite: bool = True
-     ) -> None:
-         """Traverse the BIDS dataset at data_dir and add its records to the MongoDB database,
-         under the given dataset name.
+         self,
+         dataset: str,
+         data_dir: str,
+         overwrite: bool = True,
+         output_path: str | Path | None = None,
+     ) -> dict[str, Any]:
+         """Collect metadata for a local BIDS dataset as JSON-ready records.
+
+         Instead of inserting records directly into MongoDB, this method scans
+         ``data_dir`` and returns a JSON-serializable manifest describing every
+         EEG recording that was discovered. The manifest can be written to disk
+         or forwarded to the EEGDash ingestion API for persistence.

          Parameters
          ----------
-         dataset : str)
-             The name of the dataset to be added (e.g., "ds002718").
+         dataset : str
+             Dataset identifier (e.g., ``"ds002718"``).
          data_dir : str
-             The path to the BIDS dataset directory.
-         overwrite : bool
-             Whether to overwrite/update existing records in the database.
+             Path to the local BIDS dataset directory.
+         overwrite : bool, default True
+             If ``False``, skip records that already exist in the database based
+             on ``data_name`` lookups.
+         output_path : str | Path | None, optional
+             If provided, the manifest is written to the given JSON file.

-         """
-         if self.is_public:
-             raise ValueError("This operation is not allowed for public users")
+         Returns
+         -------
+         dict
+             A manifest with keys ``dataset``, ``source``, ``records`` and, when
+             applicable, ``skipped`` or ``errors``.

-         if not overwrite and self.exist({"dataset": dataset}):
-             logger.info("Dataset %s already exists in the database", dataset)
-             return
+         """
+         source_dir = Path(data_dir).expanduser()
          try:
              bids_dataset = EEGBIDSDataset(
-                 data_dir=data_dir,
+                 data_dir=str(source_dir),
                  dataset=dataset,
              )
-         except Exception as e:
-             logger.error("Error creating bids dataset %s: $s", dataset, str(e))
-             raise e
-         requests = []
+         except Exception as exc:
+             logger.error("Error creating BIDS dataset %s: %s", dataset, exc)
+             raise exc
+
+         records: list[dict[str, Any]] = []
+         skipped: list[str] = []
+         errors: list[dict[str, str]] = []
+
          for bids_file in bids_dataset.get_files():
-             try:
-                 data_id = f"{dataset}_{Path(bids_file).name}"
-
-                 if self.exist({"data_name": data_id}):
-                     if overwrite:
-                         eeg_attrs = self.load_eeg_attrs_from_bids_file(
-                             bids_dataset, bids_file
-                         )
-                         requests.append(self.update_request(eeg_attrs))
-                 else:
-                     eeg_attrs = self.load_eeg_attrs_from_bids_file(
-                         bids_dataset, bids_file
+             data_id = f"{dataset}_{Path(bids_file).name}"
+             if not overwrite:
+                 try:
+                     if self.exist({"data_name": data_id}):
+                         skipped.append(data_id)
+                         continue
+                 except Exception as exc:
+                     logger.warning(
+                         "Could not verify existing record %s due to: %s",
+                         data_id,
+                         exc,
                      )
-                     requests.append(self.add_request(eeg_attrs))
-             except Exception as e:
-                 logger.error("Error adding record %s", bids_file)
-                 logger.error(str(e))

-         logger.info("Number of requests: %s", len(requests))
+             try:
+                 eeg_attrs = load_eeg_attrs_from_bids_file(bids_dataset, bids_file)
+                 records.append(eeg_attrs)
+             except Exception as exc:  # log and continue collecting
+                 logger.error("Error extracting metadata for %s", bids_file)
+                 logger.error(str(exc))
+                 errors.append({"file": str(bids_file), "error": str(exc)})
+
+         manifest: dict[str, Any] = {
+             "dataset": dataset,
+             "source": str(source_dir.resolve()),
+             "record_count": len(records),
+             "records": records,
+         }
+         if skipped:
+             manifest["skipped"] = skipped
+         if errors:
+             manifest["errors"] = errors
+
+         if output_path is not None:
+             output_path = Path(output_path)
+             output_path.parent.mkdir(parents=True, exist_ok=True)
+             with output_path.open("w", encoding="utf-8") as fh:
+                 json.dump(
+                     manifest,
+                     fh,
+                     indent=2,
+                     sort_keys=True,
+                     default=_json_default,
+                 )
+             logger.info(
+                 "Wrote EEGDash ingestion manifest for %s to %s",
+                 dataset,
+                 output_path,
+             )
+
+         logger.info(
+             "Prepared %s records for dataset %s (skipped=%s, errors=%s)",
+             len(records),
+             dataset,
+             len(skipped),
+             len(errors),
+         )
+
+         return manifest

-         if requests:
-             result = self.__collection.bulk_write(requests, ordered=False)
-             logger.info("Inserted: %s ", result.inserted_count)
-             logger.info("Modified: %s ", result.modified_count)
-             logger.info("Deleted: %s", result.deleted_count)
-             logger.info("Upserted: %s", result.upserted_count)
-             logger.info("Errors: %s ", result.bulk_api_result.get("writeErrors", []))
+     def exists(self, query: dict[str, Any]) -> bool:
+         """Check if at least one record matches the query.

-     def get(self, query: dict[str, Any]) -> list[xr.DataArray]:
-         """Retrieve a list of EEG data arrays that match the given query. See also
-         the `find()` method for details on the query format.
+         This is an alias for :meth:`exist`.

          Parameters
          ----------
          query : dict
-             A dictionary that specifies the query to be executed; this is a reference
-             document that is used to match records in the MongoDB collection.
+             MongoDB query to check for existence.

          Returns
          -------
-         A list of xarray DataArray objects containing the EEG data for each matching record.
-
-         Notes
-         -----
-         Retrieval is done in parallel, and the downloaded data are not cached locally.
-
-         """
-         sessions = self.find(query)
-         results = []
-         if sessions:
-             logger.info("Found %s records", len(sessions))
-             results = Parallel(
-                 n_jobs=-1 if len(sessions) > 1 else 1, prefer="threads", verbose=1
-             )(
-                 delayed(self.load_eeg_data_from_s3)(self.get_s3path(session))
-                 for session in sessions
-             )
-         return results
-
-     def add_request(self, record: dict):
-         """Internal helper method to create a MongoDB insertion request for a record."""
-         return InsertOne(record)
-
-     def add(self, record: dict):
-         """Add a single record to the MongoDB collection."""
-         try:
-             self.__collection.insert_one(record)
-         except ValueError as e:
-             logger.error("Validation error for record: %s ", record["data_name"])
-             logger.error(e)
-         except:
-             logger.error("Error adding record: %s ", record["data_name"])
-
-     def update_request(self, record: dict):
-         """Internal helper method to create a MongoDB update request for a record."""
-         return UpdateOne({"data_name": record["data_name"]}, {"$set": record})
-
-     def update(self, record: dict):
-         """Update a single record in the MongoDB collection."""
-         try:
-             self.__collection.update_one(
-                 {"data_name": record["data_name"]}, {"$set": record}
-             )
-         except:  # silent failure
-             logger.error("Error updating record: %s", record["data_name"])
-
-     def remove_field(self, record, field):
-         """Remove a specific field from a record in the MongoDB collection."""
-         self.__collection.update_one(
-             {"data_name": record["data_name"]}, {"$unset": {field: 1}}
-         )
+         bool
+             True if a matching record exists, False otherwise.

-     def remove_field_from_db(self, field):
-         """Removed all occurrences of a specific field from all records in the MongoDB
-         collection. WARNING: this operation is destructive and should be used with caution.
          """
-         self.__collection.update_many({}, {"$unset": {field: 1}})
+         return self.exist(query)

      @property
      def collection(self):
-         """Return the MongoDB collection object."""
-         return self.__collection
+         """The underlying PyMongo ``Collection`` object.

-     def close(self):
-         """Close the MongoDB client connection.
+         Returns
+         -------
+         pymongo.collection.Collection
+             The collection object used for database interactions.

-         Note: Since MongoDB clients are now managed by a singleton,
-         this method no longer closes connections. Use close_all_connections()
-         class method to close all connections if needed.
          """
-         # Individual instances no longer close the shared client
-         pass
+         return self.__collection

      @classmethod
-     def close_all_connections(cls):
-         """Close all MongoDB client connections managed by the singleton."""
+     def close_all_connections(cls) -> None:
+         """Close all MongoDB client connections managed by the singleton manager."""
          MongoConnectionManager.close_all()

-     def __del__(self):
-         """Ensure connection is closed when object is deleted."""
-         # No longer needed since we're using singleton pattern
-         pass
-
-
- class EEGDashDataset(BaseConcatDataset):
-     def __init__(
-         self,
-         query: dict | None = None,
-         data_dir: str | list | None = None,
-         dataset: str | list | None = None,
-         description_fields: list[str] = [
-             "subject",
-             "session",
-             "run",
-             "task",
-             "age",
-             "gender",
-             "sex",
-         ],
-         cache_dir: str = "~/eegdash_cache",
-         s3_bucket: str | None = None,
-         eeg_dash_instance=None,
-         **kwargs,
-     ):
-         """Create a new EEGDashDataset from a given query or local BIDS dataset directory
-         and dataset name. An EEGDashDataset is pooled collection of EEGDashBaseDataset
-         instances (individual recordings) and is a subclass of braindecode's BaseConcatDataset.
-
-         Parameters
-         ----------
-         query : dict | None
-             Optionally a dictionary that specifies the query to be executed; see
-             EEGDash.find() for details on the query format.
-         data_dir : str | list[str] | None
-             Optionally a string or a list of strings specifying one or more local
-             BIDS dataset directories from which to load the EEG data files. Exactly one
-             of query or data_dir must be provided.
-         dataset : str | list[str] | None
-             If data_dir is given, a name or list of names for for the dataset(s) to be loaded.
-         description_fields : list[str]
-             A list of fields to be extracted from the dataset records
-             and included in the returned data description(s). Examples are typical
-             subject metadata fields such as "subject", "session", "run", "task", etc.;
-             see also data_config.description_fields for the default set of fields.
-         cache_dir : str
-             A directory where the dataset will be cached locally.
-         s3_bucket : str | None
-             An optional S3 bucket URI (e.g., "s3://mybucket") to use instead of the
-             default OpenNeuro bucket for loading data files
-         kwargs : dict
-             Additional keyword arguments to be passed to the EEGDashBaseDataset
-             constructor.
-
-         """
-         self.cache_dir = cache_dir
-         self.s3_bucket = s3_bucket
-         self.eeg_dash = eeg_dash_instance or EEGDash()
-         _owns_client = eeg_dash_instance is None
-
-         try:
-             if query:
-                 datasets = self.find_datasets(query, description_fields, **kwargs)
-             elif data_dir:
-                 if isinstance(data_dir, str):
-                     datasets = self.load_bids_dataset(
-                         dataset, data_dir, description_fields, s3_bucket, **kwargs
-                     )
-                 else:
-                     assert len(data_dir) == len(dataset), (
-                         "Number of datasets and their directories must match"
-                     )
-                     datasets = []
-                     for i, _ in enumerate(data_dir):
-                         datasets.extend(
-                             self.load_bids_dataset(
-                                 dataset[i],
-                                 data_dir[i],
-                                 description_fields,
-                                 s3_bucket,
-                                 **kwargs,
-                             )
-                         )
-             else:
-                 raise ValueError(
-                     "Exactly one of 'query' or 'data_dir' must be provided."
-                 )
-         finally:
-             # If we created the client, close it now that construction is done.
-             if _owns_client:
-                 try:
-                     self.eeg_dash.close()
-                 except Exception:
-                     # Don't let close errors break construction
-                     pass
-
-         self.filesystem = S3FileSystem(
-             anon=True, client_kwargs={"region_name": "us-east-2"}
-         )
-
-         self.eeg_dash.close()
-
-         super().__init__(datasets)
-
-     def find_key_in_nested_dict(self, data: Any, target_key: str) -> Any:
-         """Helper to recursively search for a key in a nested dictionary structure; returns
-         the value associated with the first occurrence of the key, or None if not found.
-         """
-         if isinstance(data, dict):
-             if target_key in data:
-                 return data[target_key]
-             for value in data.values():
-                 result = self.find_key_in_nested_dict(value, target_key)
-                 if result is not None:
-                     return result
-         return None
-
-     def find_datasets(
-         self, query: dict[str, Any], description_fields: list[str], **kwargs
-     ) -> list[EEGDashBaseDataset]:
-         """Helper method to find datasets in the MongoDB collection that satisfy the
-         given query and return them as a list of EEGDashBaseDataset objects.
-
-         Parameters
-         ----------
-         query : dict
-             The query object, as in EEGDash.find().
-         description_fields : list[str]
-             A list of fields to be extracted from the dataset records and included in
-             the returned dataset description(s).
-         kwargs: additional keyword arguments to be passed to the EEGDashBaseDataset
-             constructor.

-         Returns
-         -------
-         list :
-             A list of EEGDashBaseDataset objects that match the query.
+ def _json_default(value: Any) -> Any:
+     """Fallback serializer for complex objects when exporting ingestion JSON."""
+     try:
+         if isinstance(value, (np.generic,)):
+             return value.item()
+         if isinstance(value, np.ndarray):
+             return value.tolist()
+     except Exception:
+         pass

-         """
-         datasets: list[EEGDashBaseDataset] = []
-         for record in self.eeg_dash.find(query):
-             description = {}
-             for field in description_fields:
-                 value = self.find_key_in_nested_dict(record, field)
-                 if value is not None:
-                     description[field] = value
-             datasets.append(
-                 EEGDashBaseDataset(
-                     record,
-                     self.cache_dir,
-                     self.s3_bucket,
-                     description=description,
-                     **kwargs,
-                 )
-             )
-         return datasets
+     try:
+         if value is pd.NA:
+             return None
+         if isinstance(value, (pd.Timestamp, pd.Timedelta)):
+             return value.isoformat()
+         if isinstance(value, pd.Series):
+             return value.to_dict()
+     except Exception:
+         pass

-     def load_bids_dataset(
-         self,
-         dataset,
-         data_dir,
-         description_fields: list[str],
-         s3_bucket: str | None = None,
-         **kwargs,
-     ):
-         """Helper method to load a single local BIDS dataset and return it as a list of
-         EEGDashBaseDatasets (one for each recording in the dataset).
+     if isinstance(value, Path):
+         return value.as_posix()
+     if isinstance(value, set):
+         return sorted(value)

-         Parameters
-         ----------
-         dataset : str
-             A name for the dataset to be loaded (e.g., "ds002718").
-         data_dir : str
-             The path to the local BIDS dataset directory.
-         description_fields : list[str]
-             A list of fields to be extracted from the dataset records
-             and included in the returned dataset description(s).
+     raise TypeError(f"Object of type {type(value).__name__} is not JSON serializable")

-         """
-         bids_dataset = EEGBIDSDataset(
-             data_dir=data_dir,
-             dataset=dataset,
-         )
-         datasets = Parallel(n_jobs=-1, prefer="threads", verbose=1)(
-             delayed(self.get_base_dataset_from_bids_file)(
-                 bids_dataset=bids_dataset,
-                 bids_file=bids_file,
-                 s3_bucket=s3_bucket,
-                 description_fields=description_fields,
-                 **kwargs,
-             )
-             for bids_file in bids_dataset.get_files()
-         )
-         return datasets

-     def get_base_dataset_from_bids_file(
-         self,
-         bids_dataset: "EEGBIDSDataset",
-         bids_file: str,
-         s3_bucket: str | None,
-         description_fields: list[str],
-         **kwargs,
-     ) -> "EEGDashBaseDataset":
-         """Instantiate a single EEGDashBaseDataset given a local BIDS file (metadata only)."""
-         record = self.eeg_dash.load_eeg_attrs_from_bids_file(bids_dataset, bids_file)
-         description = {}
-         for field in description_fields:
-             value = self.find_key_in_nested_dict(record, field)
-             if value is not None:
-                 description[field] = value
-         return EEGDashBaseDataset(
-             record,
-             self.cache_dir,
-             s3_bucket,
-             description=description,
-             **kwargs,
-         )
+ __all__ = ["EEGDash"]
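Taken together, the reworked add_bids_dataset() and the _json_default() fallback turn ingestion into a two-step flow: collect a JSON manifest locally, then persist it separately. A hedged end-to-end sketch; the local dataset path is a placeholder:

    from eegdash.api import EEGDash

    eegdash = EEGDash()

    # Scan a local BIDS tree and write the manifest to disk; numpy/pandas/Path
    # values inside the records are serialized through _json_default.
    manifest = eegdash.add_bids_dataset(
        dataset="ds002718",
        data_dir="/data/ds002718",  # placeholder path
        overwrite=False,            # skip records already in the database
        output_path="ds002718_manifest.json",
    )

    print(manifest["record_count"], len(manifest.get("skipped", [])))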