eegdash-0.3.4.dev70-py3-none-any.whl → eegdash-0.3.5.dev80-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of eegdash might be problematic; see the package registry's release page for more details.

eegdash/__init__.py CHANGED
@@ -7,4 +7,4 @@ __init__mongo_client()
7
7
 
8
8
  __all__ = ["EEGDash", "EEGDashDataset", "EEGChallengeDataset"]
9
9
 
10
- __version__ = "0.3.4.dev70"
10
+ __version__ = "0.3.5.dev80"
eegdash/api.py CHANGED
@@ -9,6 +9,7 @@ import numpy as np
9
9
  import xarray as xr
10
10
  from dotenv import load_dotenv
11
11
  from joblib import Parallel, delayed
12
+ from mne_bids import get_bids_path_from_fname, read_raw_bids
12
13
  from pymongo import InsertOne, UpdateOne
13
14
  from s3fs import S3FileSystem
14
15
 
@@ -34,6 +35,19 @@ class EEGDash:
34
35
 
35
36
  """
36
37
 
38
+ _ALLOWED_QUERY_FIELDS = {
39
+ "data_name",
40
+ "dataset",
41
+ "subject",
42
+ "task",
43
+ "session",
44
+ "run",
45
+ "modality",
46
+ "sampling_frequency",
47
+ "nchans",
48
+ "ntimes",
49
+ }
50
+
37
51
  def __init__(self, *, is_public: bool = True, is_staging: bool = False) -> None:
38
52
  """Create new instance of the EEGDash Database client.
39
53
 
@@ -71,34 +85,59 @@ class EEGDash:
71
85
  anon=True, client_kwargs={"region_name": "us-east-2"}
72
86
  )
73
87
 
74
- def find(self, query: dict[str, Any], *args, **kwargs) -> list[Mapping[str, Any]]:
75
- """Find records in the MongoDB collection that satisfy the given query.
88
+ def find(
89
+ self, query: dict[str, Any] = None, /, **kwargs
90
+ ) -> list[Mapping[str, Any]]:
91
+ """Find records in the MongoDB collection.
92
+
93
+ This method can be called in two ways:
94
+ 1. With a pre-built MongoDB query dictionary (positional argument):
95
+ >>> eegdash.find({"dataset": "ds002718", "subject": {"$in": ["012", "013"]}})
96
+ 2. With user-friendly keyword arguments for simple and multi-value queries:
97
+ >>> eegdash.find(dataset="ds002718", subject="012")
98
+ >>> eegdash.find(dataset="ds002718", subject=["012", "013"])
76
99
 
77
100
  Parameters
78
101
  ----------
79
- query: dict
80
- A dictionary that specifies the query to be executed; this is a reference
81
- document that is used to match records in the MongoDB collection.
82
- args:
83
- Additional positional arguments for the MongoDB find() method; see
84
- https://pymongo.readthedocs.io/en/stable/api/pymongo/collection.html#pymongo.collection.Collection.find
85
- kwargs:
86
- Additional keyword arguments for the MongoDB find() method.
102
+ query: dict, optional
103
+ A complete MongoDB query dictionary. This is a positional-only argument.
104
+ **kwargs:
105
+ Keyword arguments representing field-value pairs for the query.
106
+ Values can be single items (str, int) or lists of items for multi-search.
87
107
 
88
108
  Returns
89
109
  -------
90
110
  list:
91
111
  A list of DB records (string-keyed dictionaries) that match the query.
92
112
 
93
- Example
94
- -------
95
- >>> eegdash = EEGDash()
96
- >>> eegdash.find({"dataset": "ds002718", "subject": "012"})
113
+ Raises
114
+ ------
115
+ ValueError
116
+ If both a `query` dictionary and keyword arguments are provided.
97
117
 
98
118
  """
99
- results = self.__collection.find(query, *args, **kwargs)
119
+ if query is not None and kwargs:
120
+ raise ValueError(
121
+ "Provide either a positional 'query' dictionary or keyword arguments, not both."
122
+ )
100
123
 
101
- return [result for result in results]
124
+ final_query = {}
125
+ if query is not None:
126
+ final_query = query
127
+ elif kwargs:
128
+ final_query = self._build_query_from_kwargs(**kwargs)
129
+ else:
130
+ # By default, an empty query {} returns all documents.
131
+ # This can be dangerous, so we can either allow it or raise an error.
132
+ # Let's require an explicit query for safety.
133
+ raise ValueError(
134
+ "find() requires a query dictionary or at least one keyword argument. "
135
+ "To find all documents, use find({})."
136
+ )
137
+
138
+ results = self.__collection.find(final_query)
139
+
140
+ return list(results)
102
141
 
103
142
  def exist(self, query: dict[str, Any]) -> bool:
104
143
  """Return True if at least one record matches the query, else False.
@@ -184,6 +223,35 @@ class EEGDash:
184
223
 
185
224
  return record
186
225
 
226
+ def _build_query_from_kwargs(self, **kwargs) -> dict[str, Any]:
227
+ """Builds and validates a MongoDB query from user-friendly keyword arguments.
228
+
229
+ Translates list values into MongoDB's `$in` operator.
230
+ """
231
+ # 1. Validate that all provided keys are allowed for querying
232
+ unknown_fields = set(kwargs.keys()) - self._ALLOWED_QUERY_FIELDS
233
+ if unknown_fields:
234
+ raise ValueError(
235
+ f"Unsupported query field(s): {', '.join(sorted(unknown_fields))}. "
236
+ f"Allowed fields are: {', '.join(sorted(self._ALLOWED_QUERY_FIELDS))}"
237
+ )
238
+
239
+ # 2. Construct the query dictionary
240
+ query = {}
241
+ for key, value in kwargs.items():
242
+ if isinstance(value, (list, tuple)):
243
+ if not value:
244
+ raise ValueError(
245
+ f"Received an empty list for query parameter '{key}'. This is not supported."
246
+ )
247
+ # If the value is a list, use the `$in` operator for multi-search
248
+ query[key] = {"$in": value}
249
+ else:
250
+ # Otherwise, it's a direct match
251
+ query[key] = value
252
+
253
+ return query
254
+
187
255
  def load_eeg_data_from_s3(self, s3path: str) -> xr.DataArray:
188
256
  """Load an EEGLAB .set file from an AWS S3 URI and return it as an xarray DataArray.
189
257
 
@@ -218,14 +286,15 @@ class EEGDash:
218
286
  Parameters
219
287
  ----------
220
288
  bids_file : str
221
- Path to the file on the local filesystem.
289
+ Path to the BIDS-compliant file on the local filesystem.
222
290
 
223
291
  Notes
224
292
  -----
225
293
  Currently, only non-epoched .set files are supported.
226
294
 
227
295
  """
228
- raw_object = mne.io.read_raw(bids_file)
296
+ bids_path = get_bids_path_from_fname(bids_file, verbose=False)
297
+ raw_object = read_raw_bids(bids_path=bids_path, verbose=False)
229
298
  eeg_data = raw_object.get_data()
230
299
 
231
300
  fs = raw_object.info["sfreq"]
@@ -521,8 +590,8 @@ class EEGDashDataset(BaseConcatDataset):
521
590
  def __init__(
522
591
  self,
523
592
  query: dict | None = None,
524
- data_dir: str | list | None = None,
525
- dataset: str | list | None = None,
593
+ cache_dir: str = "~/eegdash_cache",
594
+ dataset: str | None = None,
526
595
  description_fields: list[str] = [
527
596
  "subject",
528
597
  "session",
@@ -532,36 +601,55 @@ class EEGDashDataset(BaseConcatDataset):
532
601
  "gender",
533
602
  "sex",
534
603
  ],
535
- cache_dir: str = "~/eegdash_cache",
536
604
  s3_bucket: str | None = None,
605
+ data_dir: str | None = None,
537
606
  eeg_dash_instance=None,
607
+ records: list[dict] | None = None,
538
608
  **kwargs,
539
609
  ):
540
610
  """Create a new EEGDashDataset from a given query or local BIDS dataset directory
541
611
  and dataset name. An EEGDashDataset is pooled collection of EEGDashBaseDataset
542
612
  instances (individual recordings) and is a subclass of braindecode's BaseConcatDataset.
543
613
 
614
+
615
+ Querying Examples:
616
+ ------------------
617
+ # Find by single subject
618
+ >>> ds = EEGDashDataset(dataset="ds005505", subject="NDARCA153NKE")
619
+
620
+ # Find by a list of subjects and a specific task
621
+ >>> subjects = ["NDARCA153NKE", "NDARXT792GY8"]
622
+ >>> ds = EEGDashDataset(dataset="ds005505", subject=subjects, task="RestingState")
623
+
624
+ # Use a raw MongoDB query for advanced filtering
625
+ >>> raw_query = {"dataset": "ds005505", "subject": {"$in": subjects}}
626
+ >>> ds = EEGDashDataset(query=raw_query)
627
+
544
628
  Parameters
545
629
  ----------
546
630
  query : dict | None
547
- Optionally a dictionary that specifies the query to be executed; see
548
- EEGDash.find() for details on the query format.
549
- data_dir : str | list[str] | None
550
- Optionally a string or a list of strings specifying one or more local
551
- BIDS dataset directories from which to load the EEG data files. Exactly one
631
+ A raw MongoDB query dictionary. If provided, keyword arguments for filtering are ignored.
632
+ **kwargs : dict
633
+ Keyword arguments for filtering (e.g., `subject="X"`, `task=["T1", "T2"]`) and/or
634
+ arguments to be passed to the EEGDashBaseDataset constructor (e.g., `subject=...`).
635
+ cache_dir : str
636
+ A directory where the dataset will be cached locally.
637
+ data_dir : str | None
638
+ Optionally a string specifying a local BIDS dataset directory from which to load the EEG data files. Exactly one
552
639
  of query or data_dir must be provided.
553
- dataset : str | list[str] | None
554
- If data_dir is given, a name or list of names for for the dataset(s) to be loaded.
640
+ dataset : str | None
641
+ If data_dir is given, a name for the dataset to be loaded.
555
642
  description_fields : list[str]
556
643
  A list of fields to be extracted from the dataset records
557
644
  and included in the returned data description(s). Examples are typical
558
645
  subject metadata fields such as "subject", "session", "run", "task", etc.;
559
646
  see also data_config.description_fields for the default set of fields.
560
- cache_dir : str
561
- A directory where the dataset will be cached locally.
562
647
  s3_bucket : str | None
563
648
  An optional S3 bucket URI (e.g., "s3://mybucket") to use instead of the
564
649
  default OpenNeuro bucket for loading data files
650
+ records : list[dict] | None
651
+ Optional list of pre-fetched metadata records. If provided, the dataset is
652
+ constructed directly from these records without querying MongoDB.
565
653
  kwargs : dict
566
654
  Additional keyword arguments to be passed to the EEGDashBaseDataset
567
655
  constructor.
@@ -569,50 +657,79 @@ class EEGDashDataset(BaseConcatDataset):
569
657
  """
570
658
  self.cache_dir = cache_dir
571
659
  self.s3_bucket = s3_bucket
572
- self.eeg_dash = eeg_dash_instance or EEGDash()
573
- _owns_client = eeg_dash_instance is None
660
+ self.eeg_dash = eeg_dash_instance
661
+ _owns_client = False
662
+ if self.eeg_dash is None and records is None:
663
+ self.eeg_dash = EEGDash()
664
+ _owns_client = True
665
+
666
+ # Separate query kwargs from other kwargs passed to the BaseDataset constructor
667
+ query_kwargs = {
668
+ k: v for k, v in kwargs.items() if k in EEGDash._ALLOWED_QUERY_FIELDS
669
+ }
670
+ base_dataset_kwargs = {k: v for k, v in kwargs.items() if k not in query_kwargs}
671
+
672
+ if query and query_kwargs:
673
+ raise ValueError(
674
+ "Provide either a 'query' dictionary or keyword arguments for filtering, not both."
675
+ )
574
676
 
575
677
  try:
576
- if query:
577
- datasets = self.find_datasets(query, description_fields, **kwargs)
678
+ if records is not None:
679
+ self.records = records
680
+ datasets = [
681
+ EEGDashBaseDataset(
682
+ record,
683
+ self.cache_dir,
684
+ self.s3_bucket,
685
+ **base_dataset_kwargs,
686
+ )
687
+ for record in self.records
688
+ ]
578
689
  elif data_dir:
579
- if isinstance(data_dir, str):
690
+ # This path loads from a local directory and is not affected by DB query logic
691
+ if isinstance(data_dir, str) or isinstance(data_dir, Path):
580
692
  datasets = self.load_bids_dataset(
581
- dataset, data_dir, description_fields, s3_bucket, **kwargs
693
+ dataset=dataset,
694
+ data_dir=data_dir,
695
+ description_fields=description_fields,
696
+ s3_bucket=s3_bucket,
697
+ **base_dataset_kwargs,
582
698
  )
583
699
  else:
584
700
  assert len(data_dir) == len(dataset), (
585
- "Number of datasets and their directories must match"
701
+ "Number of datasets and directories must match"
586
702
  )
587
703
  datasets = []
588
704
  for i, _ in enumerate(data_dir):
589
705
  datasets.extend(
590
706
  self.load_bids_dataset(
591
- dataset[i],
592
- data_dir[i],
593
- description_fields,
594
- s3_bucket,
595
- **kwargs,
707
+ dataset=dataset[i],
708
+ data_dir=data_dir[i],
709
+ description_fields=description_fields,
710
+ s3_bucket=s3_bucket,
711
+ **base_dataset_kwargs,
596
712
  )
597
713
  )
714
+ elif query or query_kwargs:
715
+ # This is the DB query path that we are improving
716
+ datasets = self.find_datasets(
717
+ query=query,
718
+ description_fields=description_fields,
719
+ query_kwargs=query_kwargs,
720
+ base_dataset_kwargs=base_dataset_kwargs,
721
+ )
722
+ # We only need filesystem if we need to access S3
723
+ self.filesystem = S3FileSystem(
724
+ anon=True, client_kwargs={"region_name": "us-east-2"}
725
+ )
598
726
  else:
599
727
  raise ValueError(
600
- "Exactly one of 'query' or 'data_dir' must be provided."
728
+ "You must provide either 'records', a 'data_dir', or a query/keyword arguments for filtering."
601
729
  )
602
730
  finally:
603
- # If we created the client, close it now that construction is done.
604
- if _owns_client:
605
- try:
606
- self.eeg_dash.close()
607
- except Exception:
608
- # Don't let close errors break construction
609
- pass
610
-
611
- self.filesystem = S3FileSystem(
612
- anon=True, client_kwargs={"region_name": "us-east-2"}
613
- )
614
-
615
- self.eeg_dash.close()
731
+ if _owns_client and self.eeg_dash is not None:
732
+ self.eeg_dash.close()
616
733
 
617
734
  super().__init__(datasets)
618
735
 
@@ -630,7 +747,11 @@ class EEGDashDataset(BaseConcatDataset):
630
747
  return None
631
748
 
632
749
  def find_datasets(
633
- self, query: dict[str, Any], description_fields: list[str], **kwargs
750
+ self,
751
+ query: dict[str, Any],
752
+ description_fields: list[str],
753
+ query_kwargs: dict,
754
+ base_dataset_kwargs: dict,
634
755
  ) -> list[EEGDashBaseDataset]:
635
756
  """Helper method to find datasets in the MongoDB collection that satisfy the
636
757
  given query and return them as a list of EEGDashBaseDataset objects.
@@ -652,7 +773,10 @@ class EEGDashDataset(BaseConcatDataset):
652
773
 
653
774
  """
654
775
  datasets: list[EEGDashBaseDataset] = []
655
- for record in self.eeg_dash.find(query):
776
+
777
+ self.records = self.eeg_dash.find(query, **query_kwargs)
778
+
779
+ for record in self.records:
656
780
  description = {}
657
781
  for field in description_fields:
658
782
  value = self.find_key_in_nested_dict(record, field)
@@ -664,15 +788,15 @@ class EEGDashDataset(BaseConcatDataset):
664
788
  self.cache_dir,
665
789
  self.s3_bucket,
666
790
  description=description,
667
- **kwargs,
791
+ **base_dataset_kwargs,
668
792
  )
669
793
  )
670
794
  return datasets
671
795
 
672
796
  def load_bids_dataset(
673
797
  self,
674
- dataset,
675
- data_dir,
798
+ dataset: str,
799
+ data_dir: str | Path,
676
800
  description_fields: list[str],
677
801
  s3_bucket: str | None = None,
678
802
  **kwargs,
eegdash-0.3.4.dev70.dist-info/METADATA → eegdash-0.3.5.dev80.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: eegdash
3
- Version: 0.3.4.dev70
3
+ Version: 0.3.5.dev80
4
4
  Summary: EEG data for machine learning
5
5
  Author-email: Young Truong <dt.young112@gmail.com>, Arnaud Delorme <adelorme@gmail.com>, Aviv Dotan <avivd220@gmail.com>, Oren Shriki <oren70@gmail.com>, Bruno Aristimunha <b.aristimunha@gmail.com>
6
6
  License-Expression: GPL-3.0-only
@@ -38,6 +38,7 @@ Requires-Dist: tqdm
38
38
  Requires-Dist: xarray
39
39
  Requires-Dist: h5io>=0.2.4
40
40
  Requires-Dist: pymatreader
41
+ Requires-Dist: eeglabio
41
42
  Requires-Dist: tabulate
42
43
  Provides-Extra: tests
43
44
  Requires-Dist: pytest; extra == "tests"
eegdash-0.3.4.dev70.dist-info/RECORD → eegdash-0.3.5.dev80.dist-info/RECORD CHANGED
@@ -1,5 +1,5 @@
1
- eegdash/__init__.py,sha256=z1uESq6VO66_4UpTGGFDW06PF_7WagRrULPFrTXrsYI,240
2
- eegdash/api.py,sha256=OqOZ27GYURSAZwTQHSs0QcW_6Mq1i_5XHP6KMcihb8A,27295
1
+ eegdash/__init__.py,sha256=K-EaG_ZHr-O4aH8SHFg7PP_rbyqlvoa3JcBdlGsXlTU,240
2
+ eegdash/api.py,sha256=KjmEVkfltLR5EwRnmnPp5rEDS5Oa6_dnprif9EVpeQs,32351
3
3
  eegdash/data_config.py,sha256=OS6ERO-jHrnEOfMJUehY7ieABdsRw_qWzOKJ4pzSfqw,1323
4
4
  eegdash/data_utils.py,sha256=_dycnPmGfTbYs7bc6edHxUn_m01dLYtp92_k44ffEoY,26475
5
5
  eegdash/dataset.py,sha256=ooLoxMFy2I8BY9gJl6ncTp_Gz-Rq0Z-o4NJyyomxLcU,2670
@@ -23,8 +23,8 @@ eegdash/features/feature_bank/dimensionality.py,sha256=j_Ds71Y1AbV2uLFQj8EuXQ4kz
23
23
  eegdash/features/feature_bank/signal.py,sha256=3Tb8z9gX7iZipxQJ9DSyy30JfdmW58kgvimSyZX74p8,3404
24
24
  eegdash/features/feature_bank/spectral.py,sha256=bNB7skusePs1gX7NOU6yRlw_Gr4UOCkO_ylkCgybzug,3319
25
25
  eegdash/features/feature_bank/utils.py,sha256=DGh-Q7-XFIittP7iBBxvsJaZrlVvuY5mw-G7q6C-PCI,1237
26
- eegdash-0.3.4.dev70.dist-info/licenses/LICENSE,sha256=asisR-xupy_NrQBFXnx6yqXeZcYWLvbAaiETl25iXT0,931
27
- eegdash-0.3.4.dev70.dist-info/METADATA,sha256=5jX-LB-ep0hcsCio2zFUKO3201B_0sa5gTbeha0I24k,10364
28
- eegdash-0.3.4.dev70.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
29
- eegdash-0.3.4.dev70.dist-info/top_level.txt,sha256=zavO69HQ6MyZM0aQMR2zUS6TAFc7bnN5GEpDpOpFZzU,8
30
- eegdash-0.3.4.dev70.dist-info/RECORD,,
26
+ eegdash-0.3.5.dev80.dist-info/licenses/LICENSE,sha256=asisR-xupy_NrQBFXnx6yqXeZcYWLvbAaiETl25iXT0,931
27
+ eegdash-0.3.5.dev80.dist-info/METADATA,sha256=R0-JDW1_w2p1JJjffDbuYSlHJKGv0g7nGmyl3_AtJfY,10388
28
+ eegdash-0.3.5.dev80.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
29
+ eegdash-0.3.5.dev80.dist-info/top_level.txt,sha256=zavO69HQ6MyZM0aQMR2zUS6TAFc7bnN5GEpDpOpFZzU,8
30
+ eegdash-0.3.5.dev80.dist-info/RECORD,,