eegdash 0.3.9.dev182388821__py3-none-any.whl → 0.4.0.dev132__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

eegdash/__init__.py CHANGED
@@ -7,4 +7,4 @@ _init_mongo_client()
 
 __all__ = ["EEGDash", "EEGDashDataset", "EEGChallengeDataset", "preprocessing"]
 
-__version__ = "0.3.9.dev182388821"
+__version__ = "0.4.0.dev132"
eegdash/api.py CHANGED
@@ -1,9 +1,6 @@
-import logging
 import os
-import tempfile
 from pathlib import Path
 from typing import Any, Mapping
-from urllib.parse import urlsplit
 
 import mne
 import numpy as np
@@ -11,13 +8,15 @@ import xarray as xr
 from docstring_inheritance import NumpyDocstringInheritanceInitMeta
 from dotenv import load_dotenv
 from joblib import Parallel, delayed
-from mne.utils import warn
 from mne_bids import find_matching_paths, get_bids_path_from_fname, read_raw_bids
 from pymongo import InsertOne, UpdateOne
-from s3fs import S3FileSystem
+from rich.console import Console
+from rich.panel import Panel
+from rich.text import Text
 
 from braindecode.datasets import BaseConcatDataset
 
+from . import downloader
 from .bids_eeg_metadata import (
     build_query_from_kwargs,
     load_eeg_attrs_from_bids_file,
@@ -33,10 +32,10 @@ from .data_utils import (
     EEGBIDSDataset,
     EEGDashBaseDataset,
 )
+from .logging import logger
 from .mongodb import MongoConnectionManager
 from .paths import get_default_cache_dir
-
-logger = logging.getLogger("eegdash")
+from .utils import _init_mongo_client
 
 
 class EEGDash:
@@ -74,19 +73,26 @@ class EEGDash:
 
         if self.is_public:
             DB_CONNECTION_STRING = mne.utils.get_config("EEGDASH_DB_URI")
+            if not DB_CONNECTION_STRING:
+                try:
+                    _init_mongo_client()
+                    DB_CONNECTION_STRING = mne.utils.get_config("EEGDASH_DB_URI")
+                except Exception:
+                    DB_CONNECTION_STRING = None
         else:
             load_dotenv()
             DB_CONNECTION_STRING = os.getenv("DB_CONNECTION_STRING")
 
         # Use singleton to get MongoDB client, database, and collection
+        if not DB_CONNECTION_STRING:
+            raise RuntimeError(
+                "No MongoDB connection string configured. Set MNE config 'EEGDASH_DB_URI' "
+                "or environment variable 'DB_CONNECTION_STRING'."
+            )
         self.__client, self.__db, self.__collection = MongoConnectionManager.get_client(
             DB_CONNECTION_STRING, is_staging
         )
 
-        self.filesystem = S3FileSystem(
-            anon=True, client_kwargs={"region_name": "us-east-2"}
-        )
-
     def find(
         self, query: dict[str, Any] = None, /, **kwargs
     ) -> list[Mapping[str, Any]]:
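Note: the constructor now retries configuration via `_init_mongo_client()` and fails fast with a `RuntimeError` when no connection string can be found. A minimal sketch of satisfying the check, assuming a reachable MongoDB instance (the URI below is a placeholder):

    import mne

    from eegdash import EEGDash

    # Placeholder URI; any reachable MongoDB connection string works here.
    mne.utils.set_config("EEGDASH_DB_URI", "mongodb://localhost:27017")

    client = EEGDash()  # raises RuntimeError if no connection string is configured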
@@ -310,83 +316,6 @@ class EEGDash:
                     f"Conflicting constraints for '{key}': disjoint sets {r_val!r} and {k_val!r}"
                 )
 
-    def load_eeg_data_from_s3(self, s3path: str) -> xr.DataArray:
-        """Load EEG data from an S3 URI into an ``xarray.DataArray``.
-
-        Preserves the original filename, downloads sidecar files when applicable
-        (e.g., ``.fdt`` for EEGLAB, ``.vmrk``/``.eeg`` for BrainVision), and uses
-        MNE's direct readers.
-
-        Parameters
-        ----------
-        s3path : str
-            An S3 URI (should start with "s3://").
-
-        Returns
-        -------
-        xr.DataArray
-            EEG data with dimensions ``("channel", "time")``.
-
-        Raises
-        ------
-        ValueError
-            If the file extension is unsupported.
-
-        """
-        # choose a temp dir so sidecars can be colocated
-        with tempfile.TemporaryDirectory() as tmpdir:
-            # Derive local filenames from the S3 key to keep base name consistent
-            s3_key = urlsplit(s3path).path  # e.g., "/dsXXXX/sub-.../..._eeg.set"
-            basename = Path(s3_key).name
-            ext = Path(basename).suffix.lower()
-            local_main = Path(tmpdir) / basename
-
-            # Download main file
-            with (
-                self.filesystem.open(s3path, mode="rb") as fsrc,
-                open(local_main, "wb") as fdst,
-            ):
-                fdst.write(fsrc.read())
-
-            # Determine and fetch any required sidecars
-            sidecars: list[str] = []
-            if ext == ".set":  # EEGLAB
-                sidecars = [".fdt"]
-            elif ext == ".vhdr":  # BrainVision
-                sidecars = [".vmrk", ".eeg", ".dat", ".raw"]
-
-            for sc_ext in sidecars:
-                sc_key = s3_key[: -len(ext)] + sc_ext
-                sc_uri = f"s3://{urlsplit(s3path).netloc}{sc_key}"
-                try:
-                    # If sidecar exists, download next to the main file
-                    info = self.filesystem.info(sc_uri)
-                    if info:
-                        sc_local = Path(tmpdir) / Path(sc_key).name
-                        with (
-                            self.filesystem.open(sc_uri, mode="rb") as fsrc,
-                            open(sc_local, "wb") as fdst,
-                        ):
-                            fdst.write(fsrc.read())
-                except Exception:
-                    # Sidecar not present; skip silently
-                    pass
-
-            # Read using appropriate MNE reader
-            raw = mne.io.read_raw(str(local_main), preload=True, verbose=False)
-
-            data = raw.get_data()
-            fs = raw.info["sfreq"]
-            max_time = data.shape[1] / fs
-            time_steps = np.linspace(0, max_time, data.shape[1]).squeeze()
-            channel_names = raw.ch_names
-
-            return xr.DataArray(
-                data=data,
-                dims=["channel", "time"],
-                coords={"time": time_steps, "channel": channel_names},
-            )
-
     def load_eeg_data_from_bids_file(self, bids_file: str) -> xr.DataArray:
         """Load EEG data from a local BIDS-formatted file.
 
@@ -508,39 +437,13 @@ class EEGDash:
         results = Parallel(
             n_jobs=-1 if len(sessions) > 1 else 1, prefer="threads", verbose=1
         )(
-            delayed(self.load_eeg_data_from_s3)(self._get_s3path(session))
+            delayed(downloader.load_eeg_from_s3)(
+                downloader.get_s3path("s3://openneuro.org", session["bidspath"])
+            )
             for session in sessions
         )
         return results
 
-    def _get_s3path(self, record: Mapping[str, Any] | str) -> str:
-        """Build an S3 URI from a DB record or a relative path.
-
-        Parameters
-        ----------
-        record : dict or str
-            Either a DB record containing a ``'bidspath'`` key, or a relative
-            path string under the OpenNeuro bucket.
-
-        Returns
-        -------
-        str
-            Fully qualified S3 URI.
-
-        Raises
-        ------
-        ValueError
-            If a mapping is provided but ``'bidspath'`` is missing.
-
-        """
-        if isinstance(record, str):
-            rel = record
-        else:
-            rel = record.get("bidspath")
-            if not rel:
-                raise ValueError("Record missing 'bidspath' for S3 path resolution")
-        return f"s3://openneuro.org/{rel}"
-
     def _add_request(self, record: dict):
         """Internal helper method to create a MongoDB insertion request for a record."""
         return InsertOne(record)
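Note: a sketch of the new call path, based only on the signatures visible in this hunk; the record below is a hypothetical stand-in for a MongoDB document carrying a `bidspath` field:

    from eegdash import downloader

    record = {"bidspath": "ds002718/sub-012/eeg/sub-012_task-rest_eeg.set"}
    uri = downloader.get_s3path("s3://openneuro.org", record["bidspath"])
    data = downloader.load_eeg_from_s3(uri)  # xarray.DataArray ("channel", "time")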
@@ -552,8 +455,11 @@ class EEGDash:
         except ValueError as e:
             logger.error("Validation error for record: %s ", record["data_name"])
             logger.error(e)
-        except:
-            logger.error("Error adding record: %s ", record["data_name"])
+        except Exception as exc:
+            logger.error(
+                "Error adding record: %s ", record.get("data_name", "<unknown>")
+            )
+            logger.debug("Add operation failed", exc_info=exc)
 
     def _update_request(self, record: dict):
         """Internal helper method to create a MongoDB update request for a record."""
@@ -572,8 +478,11 @@ class EEGDash:
             self.__collection.update_one(
                 {"data_name": record["data_name"]}, {"$set": record}
             )
-        except:  # silent failure
-            logger.error("Error updating record: %s", record["data_name"])
+        except Exception as exc:  # log and continue
+            logger.error(
+                "Error updating record: %s", record.get("data_name", "<unknown>")
+            )
+            logger.debug("Update operation failed", exc_info=exc)
 
     def exists(self, query: dict[str, Any]) -> bool:
         """Alias for :meth:`exist` provided for API clarity."""
@@ -654,8 +563,7 @@ class EEGDashDataset(BaseConcatDataset, metaclass=NumpyDocstringInheritanceInitM
     Parameters
     ----------
     cache_dir : str | Path
-        Directory where data are cached locally. If not specified, a default
-        cache directory under the user cache is used.
+        Directory where data are cached locally.
     query : dict | None
         Raw MongoDB query to filter records. If provided, it is merged with
         keyword filtering arguments (see ``**kwargs``) using logical AND.
@@ -726,13 +634,21 @@ class EEGDashDataset(BaseConcatDataset, metaclass=NumpyDocstringInheritanceInitM
         self.records = records
         self.download = download
         self.n_jobs = n_jobs
-        self.eeg_dash_instance = eeg_dash_instance or EEGDash()
+        self.eeg_dash_instance = eeg_dash_instance
 
-        # Resolve a unified cache directory across code/tests/CI
-        self.cache_dir = Path(cache_dir or get_default_cache_dir())
+        self.cache_dir = cache_dir
+        if self.cache_dir == "" or self.cache_dir is None:
+            self.cache_dir = get_default_cache_dir()
+            logger.warning(
+                f"Cache directory is empty, using the eegdash default path: {self.cache_dir}"
+            )
+
+        self.cache_dir = Path(self.cache_dir)
 
         if not self.cache_dir.exists():
-            warn(f"Cache directory does not exist, creating it: {self.cache_dir}")
+            logger.warning(
+                f"Cache directory does not exist, creating it: {self.cache_dir}"
+            )
             self.cache_dir.mkdir(exist_ok=True, parents=True)
 
         # Separate query kwargs from other kwargs passed to the BaseDataset constructor
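Note: an empty or missing `cache_dir` now falls back to `get_default_cache_dir()` with a logged warning instead of being resolved silently. A usage sketch (the filter kwarg is illustrative, not taken from this diff):

    from eegdash import EEGDashDataset

    # cache_dir="" triggers the default-path warning rather than an error.
    ds = EEGDashDataset(cache_dir="", dataset="ds002718")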
@@ -772,21 +688,29 @@ class EEGDashDataset(BaseConcatDataset, metaclass=NumpyDocstringInheritanceInitM
                 not _suppress_comp_warning
                 and self.query["dataset"] in RELEASE_TO_OPENNEURO_DATASET_MAP.values()
             ):
-                warn(
-                    "If you are not participating in the competition, you can ignore this warning!"
-                    "\n\n"
-                    "EEG 2025 Competition Data Notice:\n"
-                    "---------------------------------\n"
-                    " You are loading the dataset that is used in the EEG 2025 Competition:\n"
-                    "IMPORTANT: The data accessed via `EEGDashDataset` is NOT identical to what you get from `EEGChallengeDataset` object directly.\n"
-                    "and it is not what you will use for the competition. Downsampling and filtering were applied to the data"
-                    "to allow more people to participate.\n"
-                    "\n"
-                    "If you are participating in the competition, always use `EEGChallengeDataset` to ensure consistency with the challenge data.\n"
-                    "\n",
-                    UserWarning,
-                    module="eegdash",
+                message_text = Text.from_markup(
+                    "[italic]This notice is only for users who are participating in the [link=https://eeg2025.github.io/]EEG 2025 Competition[/link].[/italic]\n\n"
+                    "[bold]EEG 2025 Competition Data Notice![/bold]\n"
+                    "You are loading one of the datasets that is used in competition, but via `EEGDashDataset`.\n\n"
+                    "[bold red]IMPORTANT[/bold red]: \n"
+                    "If you download data from `EEGDashDataset`, it is [u]NOT[/u] identical to the official \n"
+                    "competition data, which is accessed via `EEGChallengeDataset`. "
+                    "The competition data has been downsampled and filtered.\n\n"
+                    "[bold]If you are participating in the competition, \nyou must use the `EEGChallengeDataset` object to ensure consistency.[/bold] \n\n"
+                    "If you are not participating in the competition, you can ignore this message."
                 )
+                warning_panel = Panel(
+                    message_text,
+                    title="[yellow]EEG 2025 Competition Data Notice[/yellow]",
+                    subtitle="[cyan]Source: EEGDashDataset[/cyan]",
+                    border_style="yellow",
+                )
+
+                try:
+                    Console().print(warning_panel)
+                except Exception:
+                    logger.warning(str(message_text))
+
         if records is not None:
             self.records = records
             datasets = [
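Note: the plain `UserWarning` becomes a rich panel, with a logger fallback for console-less environments (e.g., headless CI). A minimal reproduction of the rendering pattern:

    from rich.console import Console
    from rich.panel import Panel
    from rich.text import Text

    text = Text.from_markup("[bold]EEG 2025 Competition Data Notice![/bold]")
    panel = Panel(text, title="[yellow]Notice[/yellow]", border_style="yellow")
    try:
        Console().print(panel)
    except Exception:
        print(str(text))  # degraded fallback when no console is available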
@@ -848,16 +772,15 @@ class EEGDashDataset(BaseConcatDataset, metaclass=NumpyDocstringInheritanceInitM
                 )
             )
         elif self.query:
-            # This is the DB query path that we are improving
+            if self.eeg_dash_instance is None:
+                self.eeg_dash_instance = EEGDash()
             datasets = self._find_datasets(
                 query=build_query_from_kwargs(**self.query),
                 description_fields=description_fields,
                 base_dataset_kwargs=base_dataset_kwargs,
             )
             # We only need filesystem if we need to access S3
-            self.filesystem = S3FileSystem(
-                anon=True, client_kwargs={"region_name": "us-east-2"}
-            )
+            self.filesystem = downloader.get_s3_filesystem()
         else:
             raise ValueError(
                 "You must provide either 'records', a 'data_dir', or a query/keyword arguments for filtering."
eegdash/bids_eeg_metadata.py CHANGED
@@ -1,18 +1,23 @@
-import logging
 import re
 from pathlib import Path
 from typing import Any
 
+import pandas as pd
+from mne_bids import BIDSPath
+
 from .const import ALLOWED_QUERY_FIELDS
 from .const import config as data_config
-
-logger = logging.getLogger("eegdash")
+from .logging import logger
 
 __all__ = [
     "build_query_from_kwargs",
     "load_eeg_attrs_from_bids_file",
     "merge_participants_fields",
     "normalize_key",
+    "participants_row_for_subject",
+    "participants_extras_from_tsv",
+    "attach_participants_extras",
+    "enrich_from_participants",
 ]
 
 
@@ -72,28 +77,6 @@ def build_query_from_kwargs(**kwargs) -> dict[str, Any]:
     return query
 
 
-def _get_raw_extensions(bids_file: str, bids_dataset) -> list[str]:
-    """Helper to find paths to additional "sidecar" files that may be associated
-    with a given main data file in a BIDS dataset; paths are returned as relative to
-    the parent dataset path.
-
-    For example, if the input file is a .set file, this will return the relative path
-    to a corresponding .fdt file (if any).
-    """
-    bids_file = Path(bids_file)
-    extensions = {
-        ".set": [".set", ".fdt"],  # eeglab
-        ".edf": [".edf"],  # european
-        ".vhdr": [".eeg", ".vhdr", ".vmrk", ".dat", ".raw"],  # brainvision
-        ".bdf": [".bdf"],  # biosemi
-    }
-    return [
-        str(bids_dataset._get_relative_bidspath(bids_file.with_suffix(suffix)))
-        for suffix in extensions[bids_file.suffix]
-        if bids_file.with_suffix(suffix).exists()
-    ]
-
-
 def load_eeg_attrs_from_bids_file(bids_dataset, bids_file: str) -> dict[str, Any]:
     """Build the metadata record for a given BIDS file (single recording) in a BIDS dataset.
 
@@ -140,7 +123,7 @@ def load_eeg_attrs_from_bids_file(bids_dataset, bids_file: str) -> dict[str, Any
         eeg_json = None
 
     bids_dependencies_files = data_config["bids_dependencies_files"]
-    bidsdependencies = []
+    bidsdependencies: list[str] = []
    for extension in bids_dependencies_files:
         try:
             dep_path = bids_dataset.get_bids_metadata_files(bids_file, extension)
@@ -151,7 +134,26 @@ def load_eeg_attrs_from_bids_file(bids_dataset, bids_file: str) -> dict[str, Any
         except Exception:
             pass
 
-    bidsdependencies.extend(_get_raw_extensions(bids_file, bids_dataset))
+    bids_path = BIDSPath(
+        subject=bids_dataset.get_bids_file_attribute("subject", bids_file),
+        session=bids_dataset.get_bids_file_attribute("session", bids_file),
+        task=bids_dataset.get_bids_file_attribute("task", bids_file),
+        run=bids_dataset.get_bids_file_attribute("run", bids_file),
+        root=bids_dataset.bidsdir,
+        datatype=bids_dataset.get_bids_file_attribute("modality", bids_file),
+        suffix="eeg",
+        extension=Path(bids_file).suffix,
+        check=False,
+    )
+
+    sidecars_map = {
+        ".set": [".fdt"],
+        ".vhdr": [".eeg", ".vmrk", ".dat", ".raw"],
+    }
+    for ext in sidecars_map.get(bids_path.extension, []):
+        sidecar = bids_path.find_matching_sidecar(extension=ext, on_error="ignore")
+        if sidecar is not None:
+            bidsdependencies.append(str(bids_dataset._get_relative_bidspath(sidecar)))
 
     # Define field extraction functions with error handling
     field_extractors = {
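Note: sidecar discovery now goes through `BIDSPath.find_matching_sidecar` instead of the removed suffix-swapping helper, so only sidecars that actually exist are recorded. A standalone sketch with placeholder entities and root:

    from mne_bids import BIDSPath

    bp = BIDSPath(
        subject="012", task="rest", root="/data/ds002718",
        datatype="eeg", suffix="eeg", extension=".set", check=False,
    )
    # Returns None instead of raising when the .fdt payload is absent.
    fdt = bp.find_matching_sidecar(extension=".fdt", on_error="ignore")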
@@ -252,3 +254,123 @@ def merge_participants_fields(
         if norm_key not in description:
             description[norm_key] = part_value
     return description
+
+
+def participants_row_for_subject(
+    bids_root: str | Path,
+    subject: str,
+    id_columns: tuple[str, ...] = ("participant_id", "participant", "subject"),
+) -> pd.Series | None:
+    """Load participants.tsv and return the row for a subject.
+
+    - Accepts either "01" or "sub-01" as the subject identifier.
+    - Returns a pandas Series for the first matching row, or None if not found.
+    """
+    try:
+        participants_tsv = Path(bids_root) / "participants.tsv"
+        if not participants_tsv.exists():
+            return None
+
+        df = pd.read_csv(
+            participants_tsv, sep="\t", dtype="string", keep_default_na=False
+        )
+        if df.empty:
+            return None
+
+        candidates = {str(subject), f"sub-{subject}"}
+        present_cols = [c for c in id_columns if c in df.columns]
+        if not present_cols:
+            return None
+
+        mask = pd.Series(False, index=df.index)
+        for col in present_cols:
+            mask |= df[col].isin(candidates)
+        match = df.loc[mask]
+        if match.empty:
+            return None
+        return match.iloc[0]
+    except Exception:
+        return None
+
295
+
296
+ def participants_extras_from_tsv(
297
+ bids_root: str | Path,
298
+ subject: str,
299
+ *,
300
+ id_columns: tuple[str, ...] = ("participant_id", "participant", "subject"),
301
+ na_like: tuple[str, ...] = ("", "n/a", "na", "nan", "unknown", "none"),
302
+ ) -> dict[str, Any]:
303
+ """Return non-identifier, non-empty participants.tsv fields for a subject.
304
+
305
+ Uses vectorized pandas operations to drop id columns and NA-like values.
306
+ """
307
+ row = participants_row_for_subject(bids_root, subject, id_columns=id_columns)
308
+ if row is None:
309
+ return {}
310
+
311
+ # Drop identifier columns and clean values
312
+ extras = row.drop(labels=[c for c in id_columns if c in row.index], errors="ignore")
313
+ s = extras.astype("string").str.strip()
314
+ valid = ~s.isna() & ~s.str.lower().isin(na_like)
315
+ return s[valid].to_dict()
316
+
317
+
318
+ def attach_participants_extras(
319
+ raw: Any,
320
+ description: Any,
321
+ extras: dict[str, Any],
322
+ ) -> None:
323
+ """Attach extras to Raw.info and dataset description without overwriting.
324
+
325
+ - Adds to ``raw.info['subject_info']['participants_extras']``.
326
+ - Adds to ``description`` if dict or pandas Series (only missing keys).
327
+ """
328
+ if not extras:
329
+ return
330
+
331
+ # Raw.info enrichment
332
+ try:
333
+ subject_info = raw.info.get("subject_info") or {}
334
+ if not isinstance(subject_info, dict):
335
+ subject_info = {}
336
+ pe = subject_info.get("participants_extras") or {}
337
+ if not isinstance(pe, dict):
338
+ pe = {}
339
+ for k, v in extras.items():
340
+ pe.setdefault(k, v)
341
+ subject_info["participants_extras"] = pe
342
+ raw.info["subject_info"] = subject_info
343
+ except Exception:
344
+ pass
345
+
346
+ # Description enrichment
347
+ try:
348
+ import pandas as _pd # local import to avoid hard dependency at import time
349
+
350
+ if isinstance(description, dict):
351
+ for k, v in extras.items():
352
+ description.setdefault(k, v)
353
+ elif isinstance(description, _pd.Series):
354
+ missing = [k for k in extras.keys() if k not in description.index]
355
+ if missing:
356
+ description.loc[missing] = [extras[m] for m in missing]
357
+ except Exception:
358
+ pass
359
+
360
+
361
+ def enrich_from_participants(
362
+ bids_root: str | Path,
363
+ bidspath: BIDSPath,
364
+ raw: Any,
365
+ description: Any,
366
+ ) -> dict[str, Any]:
367
+ """Convenience wrapper: read participants.tsv and attach extras for this subject.
368
+
369
+ Returns the extras dictionary for further use if needed.
370
+ """
371
+ subject = getattr(bidspath, "subject", None)
372
+ if not subject:
373
+ return {}
374
+ extras = participants_extras_from_tsv(bids_root, subject)
375
+ attach_participants_extras(raw, description, extras)
376
+ return extras
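Note: an end-to-end sketch of the convenience wrapper against a placeholder dataset; the extras land both in `raw.info['subject_info']['participants_extras']` and in the description mapping:

    from mne_bids import BIDSPath, read_raw_bids

    from eegdash.bids_eeg_metadata import enrich_from_participants

    bp = BIDSPath(
        subject="012", task="rest", root="/data/ds002718",
        datatype="eeg", suffix="eeg", check=False,
    )
    raw = read_raw_bids(bp, verbose=False)
    description = {"subject": "012"}
    extras = enrich_from_participants(bp.root, bp, raw, description)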