eegdash 0.4.1__py3-none-any.whl → 0.4.1.dev185__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

eegdash/__init__.py CHANGED
@@ -9,8 +9,8 @@ EEG datasets. It integrates with cloud storage, MongoDB databases, and machine l
 frameworks to streamline EEG research workflows.
 """
 
-from .api import EEGDash, EEGDashDataset
-from .dataset import EEGChallengeDataset
+from .api import EEGDash
+from .dataset import EEGChallengeDataset, EEGDashDataset
 from .hbn import preprocessing
 from .utils import _init_mongo_client
 
@@ -18,4 +18,4 @@ _init_mongo_client()
 
 __all__ = ["EEGDash", "EEGDashDataset", "EEGChallengeDataset", "preprocessing"]
 
-__version__ = "0.4.1"
+__version__ = "0.4.1.dev185"
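
Net effect of these two hunks: EEGDashDataset is no longer defined in eegdash.api but is re-exported from the new eegdash.dataset subpackage, and the top-level __all__ is unchanged, so existing user imports keep working. A minimal sanity check, assuming the dev wheel is installed (a sketch, not part of the diff):

    import eegdash

    # Both import paths should resolve to the same class after the move.
    from eegdash import EEGDashDataset as top_level
    from eegdash.dataset import EEGDashDataset as relocated
    assert top_level is relocated
    print(eegdash.__version__)  # expected: "0.4.1.dev185"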
eegdash/api.py CHANGED
@@ -15,35 +15,20 @@ from pathlib import Path
 from typing import Any, Mapping
 
 import mne
-from docstring_inheritance import NumpyDocstringInheritanceInitMeta
 from mne.utils import _soft_import
-from mne_bids import find_matching_paths
 from pymongo import InsertOne, UpdateOne
-from rich.console import Console
-from rich.panel import Panel
-from rich.text import Text
 
-from braindecode.datasets import BaseConcatDataset
-
-from . import downloader
 from .bids_eeg_metadata import (
     build_query_from_kwargs,
     load_eeg_attrs_from_bids_file,
-    merge_participants_fields,
-    normalize_key,
 )
 from .const import (
     ALLOWED_QUERY_FIELDS,
-    RELEASE_TO_OPENNEURO_DATASET_MAP,
 )
 from .const import config as data_config
-from .data_utils import (
-    EEGBIDSDataset,
-    EEGDashBaseDataset,
-)
+from .dataset.bids_dataset import EEGBIDSDataset
 from .logging import logger
 from .mongodb import MongoConnectionManager
-from .paths import get_default_cache_dir
 from .utils import _init_mongo_client
 
 
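This hunk drops api.py's dependencies on docstring_inheritance, mne_bids, rich, and braindecode, presumably because the EEGDashDataset class that used them (removed in the next hunk) now lives in the eegdash.dataset subpackage, and it re-points EEGBIDSDataset at the new eegdash.dataset.bids_dataset module. Downstream code that imported EEGBIDSDataset from the old location would be updated along these lines (a sketch; the old eegdash.data_utils path may no longer exist in the dev release):

    # Before (0.4.1):
    # from eegdash.data_utils import EEGBIDSDataset
    # After (0.4.1.dev185):
    from eegdash.dataset.bids_dataset import EEGBIDSDataset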
@@ -582,467 +567,4 @@ class EEGDash:
         pass
 
 
-class EEGDashDataset(BaseConcatDataset, metaclass=NumpyDocstringInheritanceInitMeta):
-    """Create a new EEGDashDataset from a given query or local BIDS dataset directory
-    and dataset name. An EEGDashDataset is a pooled collection of EEGDashBaseDataset
-    instances (individual recordings) and is a subclass of braindecode's BaseConcatDataset.
-
-    Examples
-    --------
-    Basic usage with dataset and subject filtering:
-
-    >>> from eegdash import EEGDashDataset
-    >>> dataset = EEGDashDataset(
-    ...     cache_dir="./data",
-    ...     dataset="ds002718",
-    ...     subject="012"
-    ... )
-    >>> print(f"Number of recordings: {len(dataset)}")
-
-    Filter by multiple subjects and specific task:
-
-    >>> subjects = ["012", "013", "014"]
-    >>> dataset = EEGDashDataset(
-    ...     cache_dir="./data",
-    ...     dataset="ds002718",
-    ...     subject=subjects,
-    ...     task="RestingState"
-    ... )
-
-    Load and inspect EEG data from recordings:
-
-    >>> if len(dataset) > 0:
-    ...     recording = dataset[0]
-    ...     raw = recording.load()
-    ...     print(f"Sampling rate: {raw.info['sfreq']} Hz")
-    ...     print(f"Number of channels: {len(raw.ch_names)}")
-    ...     print(f"Duration: {raw.times[-1]:.1f} seconds")
-
-    Advanced filtering with raw MongoDB queries:
-
-    >>> from eegdash import EEGDashDataset
-    >>> query = {
-    ...     "dataset": "ds002718",
-    ...     "subject": {"$in": ["012", "013"]},
-    ...     "task": "RestingState"
-    ... }
-    >>> dataset = EEGDashDataset(cache_dir="./data", query=query)
-
-    Working with dataset collections and braindecode integration:
-
-    >>> # EEGDashDataset is a braindecode BaseConcatDataset
-    >>> for i, recording in enumerate(dataset):
-    ...     if i >= 2:  # limit output
-    ...         break
-    ...     print(f"Recording {i}: {recording.description}")
-    ...     raw = recording.load()
-    ...     print(f"  Channels: {len(raw.ch_names)}, Duration: {raw.times[-1]:.1f}s")
-
-    Parameters
-    ----------
-    cache_dir : str | Path
-        Directory where data are cached locally.
-    query : dict | None
-        Raw MongoDB query to filter records. If provided, it is merged with
-        keyword filtering arguments (see ``**kwargs``) using logical AND.
-        You must provide at least a ``dataset`` (either in ``query`` or
-        as a keyword argument). Only fields in ``ALLOWED_QUERY_FIELDS`` are
-        considered for filtering.
-    dataset : str
-        Dataset identifier (e.g., ``"ds002718"``). Required if ``query`` does
-        not already specify a dataset.
-    task : str | list[str]
-        Task name(s) to filter by (e.g., ``"RestingState"``).
-    subject : str | list[str]
-        Subject identifier(s) to filter by (e.g., ``"NDARCA153NKE"``).
-    session : str | list[str]
-        Session identifier(s) to filter by (e.g., ``"1"``).
-    run : str | list[str]
-        Run identifier(s) to filter by (e.g., ``"1"``).
-    description_fields : list[str]
-        Fields to extract from each record and include in dataset descriptions
-        (e.g., "subject", "session", "run", "task").
-    s3_bucket : str | None
-        Optional S3 bucket URI (e.g., "s3://mybucket") to use instead of the
-        default OpenNeuro bucket when downloading data files.
-    records : list[dict] | None
-        Pre-fetched metadata records. If provided, the dataset is constructed
-        directly from these records and no MongoDB query is performed.
-    download : bool, default True
-        If False, load from local BIDS files only. Local data are expected
-        under ``cache_dir / dataset``; no DB or S3 access is attempted.
-    n_jobs : int
-        Number of parallel jobs to use where applicable (-1 uses all cores).
-    eeg_dash_instance : EEGDash | None
-        Optional existing EEGDash client to reuse for DB queries. If None,
-        a new client is created on demand (not used when download is False).
-    **kwargs : dict
-        Additional keyword arguments serving two purposes:
-
-        - Filtering: any keys present in ``ALLOWED_QUERY_FIELDS`` are treated as
-          query filters (e.g., ``dataset``, ``subject``, ``task``, ...).
-        - Dataset options: remaining keys are forwarded to
-          ``EEGDashBaseDataset``.
-
-    """
-
-    def __init__(
-        self,
-        cache_dir: str | Path,
-        query: dict[str, Any] = None,
-        description_fields: list[str] = [
-            "subject",
-            "session",
-            "run",
-            "task",
-            "age",
-            "gender",
-            "sex",
-        ],
-        s3_bucket: str | None = None,
-        records: list[dict] | None = None,
-        download: bool = True,
-        n_jobs: int = -1,
-        eeg_dash_instance: EEGDash | None = None,
-        **kwargs,
-    ):
-        # Parameters that don't need validation
-        _suppress_comp_warning: bool = kwargs.pop("_suppress_comp_warning", False)
-        self.s3_bucket = s3_bucket
-        self.records = records
-        self.download = download
-        self.n_jobs = n_jobs
-        self.eeg_dash_instance = eeg_dash_instance
-
-        self.cache_dir = cache_dir
-        if self.cache_dir == "" or self.cache_dir is None:
-            self.cache_dir = get_default_cache_dir()
-            logger.warning(
-                f"Cache directory is empty, using the eegdash default path: {self.cache_dir}"
-            )
-
-        self.cache_dir = Path(self.cache_dir)
-
-        if not self.cache_dir.exists():
-            logger.warning(
-                f"Cache directory does not exist, creating it: {self.cache_dir}"
-            )
-            self.cache_dir.mkdir(exist_ok=True, parents=True)
-
-        # Separate query kwargs from other kwargs passed to the BaseDataset constructor
-        self.query = query or {}
-        self.query.update(
-            {k: v for k, v in kwargs.items() if k in ALLOWED_QUERY_FIELDS}
-        )
-        base_dataset_kwargs = {k: v for k, v in kwargs.items() if k not in self.query}
-        if "dataset" not in self.query:
-            # If explicit records are provided, infer dataset from records
-            if isinstance(records, list) and records and isinstance(records[0], dict):
-                inferred = records[0].get("dataset")
-                if inferred:
-                    self.query["dataset"] = inferred
-                else:
-                    raise ValueError("You must provide a 'dataset' argument")
-            else:
-                raise ValueError("You must provide a 'dataset' argument")
-
-        # Decide on a dataset subfolder name for cache isolation. If using
-        # challenge/preprocessed buckets (e.g., BDF, mini subsets), append
-        # informative suffixes to avoid overlapping with the original dataset.
-        dataset_folder = self.query["dataset"]
-        if self.s3_bucket:
-            suffixes: list[str] = []
-            bucket_lower = str(self.s3_bucket).lower()
-            if "bdf" in bucket_lower:
-                suffixes.append("bdf")
-            if "mini" in bucket_lower:
-                suffixes.append("mini")
-            if suffixes:
-                dataset_folder = f"{dataset_folder}-{'-'.join(suffixes)}"
-
-        self.data_dir = self.cache_dir / dataset_folder
-
-        if (
-            not _suppress_comp_warning
-            and self.query["dataset"] in RELEASE_TO_OPENNEURO_DATASET_MAP.values()
-        ):
-            message_text = Text.from_markup(
-                "[italic]This notice is only for users who are participating in the [link=https://eeg2025.github.io/]EEG 2025 Competition[/link].[/italic]\n\n"
-                "[bold]EEG 2025 Competition Data Notice![/bold]\n"
-                "You are loading one of the datasets that is used in the competition, but via `EEGDashDataset`.\n\n"
-                "[bold red]IMPORTANT[/bold red]: \n"
-                "If you download data from `EEGDashDataset`, it is [u]NOT[/u] identical to the official \n"
-                "competition data, which is accessed via `EEGChallengeDataset`. "
-                "The competition data has been downsampled and filtered.\n\n"
-                "[bold]If you are participating in the competition, \nyou must use the `EEGChallengeDataset` object to ensure consistency.[/bold] \n\n"
-                "If you are not participating in the competition, you can ignore this message."
-            )
-            warning_panel = Panel(
-                message_text,
-                title="[yellow]EEG 2025 Competition Data Notice[/yellow]",
-                subtitle="[cyan]Source: EEGDashDataset[/cyan]",
-                border_style="yellow",
-            )
-
-            try:
-                Console().print(warning_panel)
-            except Exception:
-                logger.warning(str(message_text))
-
-        if records is not None:
-            self.records = records
-            datasets = [
-                EEGDashBaseDataset(
-                    record,
-                    self.cache_dir,
-                    self.s3_bucket,
-                    **base_dataset_kwargs,
-                )
-                for record in self.records
-            ]
-        elif not download:  # only assume local data is complete if not downloading
-            if not self.data_dir.exists():
-                raise ValueError(
-                    f"Offline mode is enabled, but local data_dir {self.data_dir} does not exist."
-                )
-            records = self._find_local_bids_records(self.data_dir, self.query)
-            # Try to enrich from local participants.tsv to restore requested fields
-            try:
-                bids_ds = EEGBIDSDataset(
-                    data_dir=str(self.data_dir), dataset=self.query["dataset"]
-                )  # type: ignore[index]
-            except Exception:
-                bids_ds = None
-
-            datasets = []
-            for record in records:
-                # Start with entity values from filename
-                desc: dict[str, Any] = {
-                    k: record.get(k)
-                    for k in ("subject", "session", "run", "task")
-                    if record.get(k) is not None
-                }
-
-                if bids_ds is not None:
-                    try:
-                        rel_from_dataset = Path(record["bidspath"]).relative_to(
-                            record["dataset"]
-                        )  # type: ignore[index]
-                        local_file = (self.data_dir / rel_from_dataset).as_posix()
-                        part_row = bids_ds.subject_participant_tsv(local_file)
-                        desc = merge_participants_fields(
-                            description=desc,
-                            participants_row=part_row
-                            if isinstance(part_row, dict)
-                            else None,
-                            description_fields=description_fields,
-                        )
-                    except Exception:
-                        pass
-
-                datasets.append(
-                    EEGDashBaseDataset(
-                        record=record,
-                        cache_dir=self.cache_dir,
-                        s3_bucket=self.s3_bucket,
-                        description=desc,
-                        **base_dataset_kwargs,
-                    )
-                )
-        elif self.query:
-            if self.eeg_dash_instance is None:
-                self.eeg_dash_instance = EEGDash()
-            datasets = self._find_datasets(
-                query=build_query_from_kwargs(**self.query),
-                description_fields=description_fields,
-                base_dataset_kwargs=base_dataset_kwargs,
-            )
-            # We only need the filesystem if we need to access S3
-            self.filesystem = downloader.get_s3_filesystem()
-        else:
-            raise ValueError(
-                "You must provide either 'records', a 'data_dir', or query/keyword arguments for filtering."
-            )
-
-        super().__init__(datasets)
-
-    def _find_local_bids_records(
-        self, dataset_root: Path, filters: dict[str, Any]
-    ) -> list[dict]:
-        """Discover local BIDS EEG files and build minimal records.
-
-        Enumerates EEG recordings under ``dataset_root`` using
-        ``mne_bids.find_matching_paths`` and applies entity filters to produce
-        records suitable for :class:`EEGDashBaseDataset`. No network access is
-        performed, and files are not read.
-
-        Parameters
-        ----------
-        dataset_root : Path
-            Local dataset directory (e.g., ``/path/to/cache/ds005509``).
-        filters : dict
-            Query filters. Must include ``'dataset'`` and may include BIDS
-            entities like ``'subject'``, ``'session'``, etc.
-
-        Returns
-        -------
-        list of dict
-            A list of records, one for each matched EEG file. Each record
-            contains BIDS entities, paths, and minimal metadata for offline use.
-
-        Notes
-        -----
-        Matching is performed for ``datatypes=['eeg']`` and ``suffixes=['eeg']``.
-        The ``bidspath`` is normalized to ensure it starts with the dataset ID,
-        even for suffixed cache directories.
-
-        """
-        dataset_id = filters["dataset"]
-        arg_map = {
-            "subjects": "subject",
-            "sessions": "session",
-            "tasks": "task",
-            "runs": "run",
-        }
-        matching_args: dict[str, list[str]] = {}
-        for finder_key, entity_key in arg_map.items():
-            entity_val = filters.get(entity_key)
-            if entity_val is None:
-                continue
-            if isinstance(entity_val, (list, tuple, set)):
-                entity_vals = list(entity_val)
-                if not entity_vals:
-                    continue
-                matching_args[finder_key] = entity_vals
-            else:
-                matching_args[finder_key] = [entity_val]
-
-        matched_paths = find_matching_paths(
-            root=str(dataset_root),
-            datatypes=["eeg"],
-            suffixes=["eeg"],
-            ignore_json=True,
-            **matching_args,
-        )
-        records_out: list[dict] = []
-
-        for bids_path in matched_paths:
-            # Build bidspath as dataset_id / relative_path_from_dataset_root (POSIX)
-            rel_from_root = (
-                Path(bids_path.fpath)
-                .resolve()
-                .relative_to(Path(bids_path.root).resolve())
-            )
-            bidspath = f"{dataset_id}/{rel_from_root.as_posix()}"
-
-            rec = {
-                "data_name": f"{dataset_id}_{Path(bids_path.fpath).name}",
-                "dataset": dataset_id,
-                "bidspath": bidspath,
-                "subject": (bids_path.subject or None),
-                "session": (bids_path.session or None),
-                "task": (bids_path.task or None),
-                "run": (bids_path.run or None),
-                # minimal fields to satisfy BaseDataset from eegdash
-                "bidsdependencies": [],  # not needed to just run.
-                "modality": "eeg",
-                # minimal numeric defaults for offline length calculation
-                "sampling_frequency": None,
-                "nchans": None,
-                "ntimes": None,
-            }
-            records_out.append(rec)
-
-        return records_out
-
-    def _find_key_in_nested_dict(self, data: Any, target_key: str) -> Any:
-        """Recursively search for a key in nested dicts/lists.
-
-        Performs a case-insensitive and underscore/hyphen-agnostic search.
-
-        Parameters
-        ----------
-        data : Any
-            The nested data structure (dicts, lists) to search.
-        target_key : str
-            The key to search for.
-
-        Returns
-        -------
-        Any
-            The value of the first matching key, or None if not found.
-
-        """
-        norm_target = normalize_key(target_key)
-        if isinstance(data, dict):
-            for k, v in data.items():
-                if normalize_key(k) == norm_target:
-                    return v
-                res = self._find_key_in_nested_dict(v, target_key)
-                if res is not None:
-                    return res
-        elif isinstance(data, list):
-            for item in data:
-                res = self._find_key_in_nested_dict(item, target_key)
-                if res is not None:
-                    return res
-        return None
-
-    def _find_datasets(
-        self,
-        query: dict[str, Any] | None,
-        description_fields: list[str],
-        base_dataset_kwargs: dict,
-    ) -> list[EEGDashBaseDataset]:
-        """Find and construct datasets from a MongoDB query.
-
-        Queries the database, then creates a list of
-        :class:`EEGDashBaseDataset` objects from the results.
-
-        Parameters
-        ----------
-        query : dict, optional
-            The MongoDB query to execute.
-        description_fields : list of str
-            Fields to extract from each record for the dataset description.
-        base_dataset_kwargs : dict
-            Additional keyword arguments to pass to the
-            :class:`EEGDashBaseDataset` constructor.
-
-        Returns
-        -------
-        list of EEGDashBaseDataset
-            A list of dataset objects matching the query.
-
-        """
-        datasets: list[EEGDashBaseDataset] = []
-        self.records = self.eeg_dash_instance.find(query)
-
-        for record in self.records:
-            description: dict[str, Any] = {}
-            # Requested fields first (normalized matching)
-            for field in description_fields:
-                value = self._find_key_in_nested_dict(record, field)
-                if value is not None:
-                    description[field] = value
-            # Merge all participants.tsv columns generically
-            part = self._find_key_in_nested_dict(record, "participant_tsv")
-            if isinstance(part, dict):
-                description = merge_participants_fields(
-                    description=description,
-                    participants_row=part,
-                    description_fields=description_fields,
-                )
-            datasets.append(
-                EEGDashBaseDataset(
-                    record,
-                    cache_dir=self.cache_dir,
-                    s3_bucket=self.s3_bucket,
-                    description=description,
-                    **base_dataset_kwargs,
-                )
-            )
-        return datasets
-
-
-__all__ = ["EEGDash", "EEGDashDataset"]
+__all__ = ["EEGDash"]
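
Although the class is deleted here, it is not gone from the package: the eegdash/dataset/__init__.py hunk below re-exports EEGDashDataset, and eegdash/__init__.py still lists it in __all__. Constructor usage from the docstring above should therefore carry over unchanged; a sketch reusing the docstring's illustrative dataset and subject IDs (downloads data on first run):

    from eegdash.dataset import EEGDashDataset

    dataset = EEGDashDataset(cache_dir="./data", dataset="ds002718", subject="012")
    print(f"Number of recordings: {len(dataset)}")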
eegdash/dataset/__init__.py CHANGED
@@ -1,7 +1,8 @@
 """Public API for dataset helpers and dynamically generated datasets."""
 
 from . import dataset as _dataset_mod  # triggers dynamic class registration
-from .dataset import EEGChallengeDataset
+from .bids_dataset import EEGBIDSDataset
+from .dataset import EEGChallengeDataset, EEGDashDataset
 from .registry import register_openneuro_datasets
 
 # Re-export dynamically generated dataset classes at the package level so that
@@ -17,6 +18,11 @@ for _name in getattr(_dataset_mod, "__all__", []):
     globals()[_name] = _obj
     _dyn_names.append(_name)
 
-__all__ = ["EEGChallengeDataset", "register_openneuro_datasets"] + _dyn_names
+__all__ = [
+    "EEGBIDSDataset",
+    "EEGDashDataset",
+    "EEGChallengeDataset",
+    "register_openneuro_datasets",
+] + _dyn_names
 
 del _dataset_mod, _name, _obj, _dyn_names
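
After upgrading, the expanded public surface of eegdash.dataset can be verified directly (a sketch; the dynamically registered OpenNeuro dataset classes collected in _dyn_names will also appear in __all__):

    import eegdash.dataset as dataset_pkg

    for name in ("EEGBIDSDataset", "EEGDashDataset", "EEGChallengeDataset"):
        assert name in dataset_pkg.__all__, f"missing export: {name}"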