eegdash 0.3.2.dev54__py3-none-any.whl → 0.3.3.dev61__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
eegdash/__init__.py CHANGED
@@ -5,4 +5,4 @@ from .utils import __init__mongo_client
5
5
  __init__mongo_client()
6
6
 
7
7
  __all__ = ["EEGDash", "EEGDashDataset", "EEGChallengeDataset"]
8
- __version__ = "0.3.2.dev54"
8
+ __version__ = "0.3.3.dev61"
eegdash/api.py CHANGED
@@ -534,6 +534,7 @@ class EEGDashDataset(BaseConcatDataset):
534
534
  ],
535
535
  cache_dir: str = "~/eegdash_cache",
536
536
  s3_bucket: str | None = None,
537
+ eeg_dash_instance=None,
537
538
  **kwargs,
538
539
  ):
539
540
  """Create a new EEGDashDataset from a given query or local BIDS dataset directory
@@ -568,29 +569,51 @@ class EEGDashDataset(BaseConcatDataset):
568
569
  """
569
570
  self.cache_dir = cache_dir
570
571
  self.s3_bucket = s3_bucket
571
- if query:
572
- datasets = self.find_datasets(query, description_fields, **kwargs)
573
- elif data_dir:
574
- if isinstance(data_dir, str):
575
- datasets = self.load_bids_dataset(
576
- dataset, data_dir, description_fields, s3_bucket
577
- )
572
+ self.eeg_dash = eeg_dash_instance or EEGDash()
573
+ _owns_client = eeg_dash_instance is None
574
+
575
+ try:
576
+ if query:
577
+ datasets = self.find_datasets(query, description_fields, **kwargs)
578
+ elif data_dir:
579
+ if isinstance(data_dir, str):
580
+ datasets = self.load_bids_dataset(
581
+ dataset, data_dir, description_fields, s3_bucket, **kwargs
582
+ )
583
+ else:
584
+ assert len(data_dir) == len(dataset), (
585
+ "Number of datasets and their directories must match"
586
+ )
587
+ datasets = []
588
+ for i, _ in enumerate(data_dir):
589
+ datasets.extend(
590
+ self.load_bids_dataset(
591
+ dataset[i],
592
+ data_dir[i],
593
+ description_fields,
594
+ s3_bucket,
595
+ **kwargs,
596
+ )
597
+ )
578
598
  else:
579
- assert len(data_dir) == len(dataset), (
580
- "Number of datasets and their directories must match"
599
+ raise ValueError(
600
+ "Exactly one of 'query' or 'data_dir' must be provided."
581
601
  )
582
- datasets = []
583
- for i, _ in enumerate(data_dir):
584
- datasets.extend(
585
- self.load_bids_dataset(
586
- dataset[i], data_dir[i], description_fields, s3_bucket
587
- )
588
- )
602
+ finally:
603
+ # If we created the client, close it now that construction is done.
604
+ if _owns_client:
605
+ try:
606
+ self.eeg_dash.close()
607
+ except Exception:
608
+ # Don't let close errors break construction
609
+ pass
589
610
 
590
611
  self.filesystem = S3FileSystem(
591
612
  anon=True, client_kwargs={"region_name": "us-east-2"}
592
613
  )
593
614
 
615
+ self.eeg_dash.close()
616
+
594
617
  super().__init__(datasets)
595
618
 
596
619
  def find_key_in_nested_dict(self, data: Any, target_key: str) -> Any:
@@ -628,27 +651,23 @@ class EEGDashDataset(BaseConcatDataset):
628
651
  A list of EEGDashBaseDataset objects that match the query.
629
652
 
630
653
  """
631
- eeg_dash_instance = EEGDash()
632
- try:
633
- datasets = []
634
- for record in eeg_dash_instance.find(query):
635
- description = {}
636
- for field in description_fields:
637
- value = self.find_key_in_nested_dict(record, field)
638
- if value is not None:
639
- description[field] = value
640
- datasets.append(
641
- EEGDashBaseDataset(
642
- record,
643
- self.cache_dir,
644
- self.s3_bucket,
645
- description=description,
646
- **kwargs,
647
- )
654
+ datasets: list[EEGDashBaseDataset] = []
655
+ for record in self.eeg_dash.find(query):
656
+ description = {}
657
+ for field in description_fields:
658
+ value = self.find_key_in_nested_dict(record, field)
659
+ if value is not None:
660
+ description[field] = value
661
+ datasets.append(
662
+ EEGDashBaseDataset(
663
+ record,
664
+ self.cache_dir,
665
+ self.s3_bucket,
666
+ description=description,
667
+ **kwargs,
648
668
  )
649
- return datasets
650
- finally:
651
- eeg_dash_instance.close()
669
+ )
670
+ return datasets
652
671
 
653
672
  def load_bids_dataset(
654
673
  self,
@@ -676,36 +695,28 @@ class EEGDashDataset(BaseConcatDataset):
676
695
  data_dir=data_dir,
677
696
  dataset=dataset,
678
697
  )
679
- eeg_dash_instance = EEGDash()
680
- try:
681
- datasets = Parallel(n_jobs=-1, prefer="threads", verbose=1)(
682
- delayed(self.get_base_dataset_from_bids_file)(
683
- bids_dataset=bids_dataset,
684
- bids_file=bids_file,
685
- eeg_dash_instance=eeg_dash_instance,
686
- s3_bucket=s3_bucket,
687
- description_fields=description_fields,
688
- )
689
- for bids_file in bids_dataset.get_files()
698
+ datasets = Parallel(n_jobs=-1, prefer="threads", verbose=1)(
699
+ delayed(self.get_base_dataset_from_bids_file)(
700
+ bids_dataset=bids_dataset,
701
+ bids_file=bids_file,
702
+ s3_bucket=s3_bucket,
703
+ description_fields=description_fields,
704
+ **kwargs,
690
705
  )
691
- return datasets
692
- finally:
693
- eeg_dash_instance.close()
706
+ for bids_file in bids_dataset.get_files()
707
+ )
708
+ return datasets
694
709
 
695
710
  def get_base_dataset_from_bids_file(
696
711
  self,
697
- bids_dataset: EEGBIDSDataset,
712
+ bids_dataset: "EEGBIDSDataset",
698
713
  bids_file: str,
699
- eeg_dash_instance: EEGDash,
700
714
  s3_bucket: str | None,
701
715
  description_fields: list[str],
702
- ) -> EEGDashBaseDataset:
703
- """Instantiate a single EEGDashBaseDataset given a local BIDS file. Note
704
- this does not actually load the data from disk, but will access the metadata.
705
- """
706
- record = eeg_dash_instance.load_eeg_attrs_from_bids_file(
707
- bids_dataset, bids_file
708
- )
716
+ **kwargs,
717
+ ) -> "EEGDashBaseDataset":
718
+ """Instantiate a single EEGDashBaseDataset given a local BIDS file (metadata only)."""
719
+ record = self.eeg_dash.load_eeg_attrs_from_bids_file(bids_dataset, bids_file)
709
720
  description = {}
710
721
  for field in description_fields:
711
722
  value = self.find_key_in_nested_dict(record, field)
@@ -716,4 +727,5 @@ class EEGDashDataset(BaseConcatDataset):
716
727
  self.cache_dir,
717
728
  s3_bucket,
718
729
  description=description,
730
+ **kwargs,
719
731
  )
eegdash/data_utils.py CHANGED
@@ -10,6 +10,7 @@ import numpy as np
10
10
  import pandas as pd
11
11
  import s3fs
12
12
  from bids import BIDSLayout
13
+ from fsspec.callbacks import TqdmCallback
13
14
  from joblib import Parallel, delayed
14
15
  from mne._fiff.utils import _read_segments_file
15
16
  from mne.io import BaseRaw
@@ -98,8 +99,27 @@ class EEGDashBaseDataset(BaseDataset):
98
99
  self.s3file = re.sub(r"(^|/)ds\d{6}/", r"\1", self.s3file, count=1)
99
100
 
100
101
  self.filecache.parent.mkdir(parents=True, exist_ok=True)
102
+ info = filesystem.info(self.s3file)
103
+ size = info.get("size") or info.get("Size")
104
+
105
+ callback = TqdmCallback(
106
+ size=size,
107
+ tqdm_kwargs=dict(
108
+ desc=f"Downloading {Path(self.s3file).name}",
109
+ unit="B",
110
+ unit_scale=True,
111
+ unit_divisor=1024,
112
+ dynamic_ncols=True,
113
+ leave=True,
114
+ mininterval=0.2,
115
+ smoothing=0.1,
116
+ miniters=1,
117
+ bar_format="{desc}: {percentage:3.0f}%|{bar}| {n_fmt}/{total_fmt} "
118
+ "[{elapsed}<{remaining}, {rate_fmt}]",
119
+ ),
120
+ )
121
+ filesystem.get(self.s3file, self.filecache, callback=callback)
101
122
 
102
- filesystem.download(self.s3file, self.filecache)
103
123
  self.filenames = [self.filecache]
104
124
 
105
125
  def _download_dependencies(self) -> None:
@@ -119,7 +139,26 @@ class EEGDashBaseDataset(BaseDataset):
119
139
  # in the case of the competition.
120
140
  if not filepath.exists():
121
141
  filepath.parent.mkdir(parents=True, exist_ok=True)
122
- filesystem.download(s3path, filepath)
142
+ info = filesystem.info(s3path)
143
+ size = info.get("size") or info.get("Size")
144
+
145
+ callback = TqdmCallback(
146
+ size=size,
147
+ tqdm_kwargs=dict(
148
+ desc=f"Downloading {Path(s3path).name}",
149
+ unit="B",
150
+ unit_scale=True,
151
+ unit_divisor=1024,
152
+ dynamic_ncols=True,
153
+ leave=True,
154
+ mininterval=0.2,
155
+ smoothing=0.1,
156
+ miniters=1,
157
+ bar_format="{desc}: {percentage:3.0f}%|{bar}| {n_fmt}/{total_fmt} "
158
+ "[{elapsed}<{remaining}, {rate_fmt}]",
159
+ ),
160
+ )
161
+ filesystem.get(s3path, filepath, callback=callback)
123
162
 
124
163
  def get_raw_bids_args(self) -> dict[str, Any]:
125
164
  """Helper to restrict the metadata record to the fields needed to locate a BIDS
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: eegdash
3
- Version: 0.3.2.dev54
3
+ Version: 0.3.3.dev61
4
4
  Summary: EEG data for machine learning
5
5
  Author-email: Young Truong <dt.young112@gmail.com>, Arnaud Delorme <adelorme@gmail.com>, Aviv Dotan <avivd220@gmail.com>, Oren Shriki <oren70@gmail.com>, Bruno Aristimunha <b.aristimunha@gmail.com>
6
6
  License: GNU General Public License
@@ -60,6 +60,8 @@ Requires-Dist: s3fs
60
60
  Requires-Dist: scipy
61
61
  Requires-Dist: tqdm
62
62
  Requires-Dist: xarray
63
+ Requires-Dist: h5io>=0.2.4
64
+ Requires-Dist: pymatreader
63
65
  Provides-Extra: tests
64
66
  Requires-Dist: pytest; extra == "tests"
65
67
  Requires-Dist: pytest-cov; extra == "tests"
@@ -1,7 +1,7 @@
1
- eegdash/__init__.py,sha256=52uzfdLjujAnXKER5pn7ft52HHmDirUOIeqT-68J7Sk,238
2
- eegdash/api.py,sha256=lYCILa69Y_RRd4_13x1Ge77FDGswnEG2DEfdPzMygAY,26865
1
+ eegdash/__init__.py,sha256=PQbCZHVxYTBpYOu1wIrGntJNKXt7bK0qh5kq8mb3ufo,238
2
+ eegdash/api.py,sha256=OqOZ27GYURSAZwTQHSs0QcW_6Mq1i_5XHP6KMcihb8A,27295
3
3
  eegdash/data_config.py,sha256=OS6ERO-jHrnEOfMJUehY7ieABdsRw_qWzOKJ4pzSfqw,1323
4
- eegdash/data_utils.py,sha256=h1FLCTjRj2JDn6IFnVgWrogDYnHFR4mcuDF23KNtLZI,24530
4
+ eegdash/data_utils.py,sha256=8Jb_94uVbdknNPpx3GBl4dCDYUIJNzl3zkLwbfH90N4,26052
5
5
  eegdash/dataset.py,sha256=GVELU-eXq9AQDzOeg6Lkykd-Pctyn42e5UEcAV0Go4s,2348
6
6
  eegdash/mongodb.py,sha256=GD3WgA253oFgpzOHrYaj4P1mRjNtDMT5Oj4kVvHswjI,2006
7
7
  eegdash/preprocessing.py,sha256=7S_TTRKPKEk47tTnh2D6WExBt4cctAMxUxGDjJqq5lU,2221
@@ -21,8 +21,8 @@ eegdash/features/feature_bank/dimensionality.py,sha256=j_Ds71Y1AbV2uLFQj8EuXQ4kz
21
21
  eegdash/features/feature_bank/signal.py,sha256=3Tb8z9gX7iZipxQJ9DSyy30JfdmW58kgvimSyZX74p8,3404
22
22
  eegdash/features/feature_bank/spectral.py,sha256=bNB7skusePs1gX7NOU6yRlw_Gr4UOCkO_ylkCgybzug,3319
23
23
  eegdash/features/feature_bank/utils.py,sha256=DGh-Q7-XFIittP7iBBxvsJaZrlVvuY5mw-G7q6C-PCI,1237
24
- eegdash-0.3.2.dev54.dist-info/licenses/LICENSE,sha256=asisR-xupy_NrQBFXnx6yqXeZcYWLvbAaiETl25iXT0,931
25
- eegdash-0.3.2.dev54.dist-info/METADATA,sha256=xGiRhU968Kcvz5Wj6RiavV6yMG72G5eSpnXsxvFX9hU,11429
26
- eegdash-0.3.2.dev54.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
27
- eegdash-0.3.2.dev54.dist-info/top_level.txt,sha256=zavO69HQ6MyZM0aQMR2zUS6TAFc7bnN5GEpDpOpFZzU,8
28
- eegdash-0.3.2.dev54.dist-info/RECORD,,
24
+ eegdash-0.3.3.dev61.dist-info/licenses/LICENSE,sha256=asisR-xupy_NrQBFXnx6yqXeZcYWLvbAaiETl25iXT0,931
25
+ eegdash-0.3.3.dev61.dist-info/METADATA,sha256=W_KFw8Hn3ekBmHKfVPzX-WS3Ylv3Xkn_yWbauePYbOY,11483
26
+ eegdash-0.3.3.dev61.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
27
+ eegdash-0.3.3.dev61.dist-info/top_level.txt,sha256=zavO69HQ6MyZM0aQMR2zUS6TAFc7bnN5GEpDpOpFZzU,8
28
+ eegdash-0.3.3.dev61.dist-info/RECORD,,