eegdash 0.3.6.dev182011805__py3-none-any.whl → 0.3.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


eegdash/data_utils.py CHANGED
@@ -1,9 +1,11 @@
+import io
 import json
 import logging
 import os
 import re
 import traceback
 import warnings
+from contextlib import redirect_stderr
 from pathlib import Path
 from typing import Any

@@ -21,6 +23,8 @@ from mne_bids import BIDSPath

 from braindecode.datasets import BaseDataset

+from .paths import get_default_cache_dir
+
 logger = logging.getLogger("eegdash")


@@ -57,7 +61,7 @@ class EEGDashBaseDataset(BaseDataset):
         super().__init__(None, **kwargs)
         self.record = record
         self.cache_dir = Path(cache_dir)
-        self.bids_kwargs = self.get_raw_bids_args()
+        self.bids_kwargs = self._get_raw_bids_args()

         if s3_bucket:
             self.s3_bucket = s3_bucket
@@ -66,8 +70,27 @@ class EEGDashBaseDataset(BaseDataset):
             self.s3_bucket = self._AWS_BUCKET
             self.s3_open_neuro = True

-        self.filecache = self.cache_dir / record["bidspath"]
-        self.bids_root = self.cache_dir / record["dataset"]
+        # Compute a dataset folder name under cache_dir that encodes preprocessing
+        # (e.g., bdf, mini) to avoid overlapping with the original dataset cache.
+        self.dataset_folder = record.get("dataset", "")
+        if s3_bucket:
+            suffixes: list[str] = []
+            bucket_lower = str(s3_bucket).lower()
+            if "bdf" in bucket_lower:
+                suffixes.append("bdf")
+            if "mini" in bucket_lower:
+                suffixes.append("mini")
+            if suffixes:
+                self.dataset_folder = f"{self.dataset_folder}-{'-'.join(suffixes)}"
+
+        # Place files under the dataset-specific folder (with suffix if any)
+        rel = Path(record["bidspath"])  # usually starts with the dataset id
+        if rel.parts and rel.parts[0] == record.get("dataset"):
+            rel = Path(self.dataset_folder, *rel.parts[1:])
+        else:
+            rel = Path(self.dataset_folder) / rel
+        self.filecache = self.cache_dir / rel
+        self.bids_root = self.cache_dir / self.dataset_folder
         self.bidspath = BIDSPath(
             root=self.bids_root,
             datatype="eeg",
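The added block above derives a variant-specific cache folder from the bucket URI. A standalone sketch of that rule (function name and example values are hypothetical; the logic mirrors the hunk):

    def dataset_cache_folder(dataset: str, s3_bucket: str | None) -> str:
        # Encode "bdf"/"mini" markers found in the bucket name into the folder
        # name, so preprocessed variants do not collide with the original cache.
        folder = dataset
        if s3_bucket:
            bucket_lower = str(s3_bucket).lower()
            suffixes = [s for s in ("bdf", "mini") if s in bucket_lower]
            if suffixes:
                folder = f"{folder}-{'-'.join(suffixes)}"
        return folder

    # dataset_cache_folder("ds005505", "s3://nmdatasets/NeurIPS25/R5_mini_L100_bdf")
    # -> "ds005505-bdf-mini"; the BIDS root then becomes cache_dir / that folder.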
@@ -75,7 +98,7 @@ class EEGDashBaseDataset(BaseDataset):
             **self.bids_kwargs,
         )

-        self.s3file = self.get_s3path(record["bidspath"])
+        self.s3file = self._get_s3path(record["bidspath"])
         self.bids_dependencies = record["bidsdependencies"]
         # Temporary fix for BIDS dependencies path,
         # needed just for the competition release
@@ -87,7 +110,7 @@ class EEGDashBaseDataset(BaseDataset):

         self._raw = None

-    def get_s3path(self, filepath: str) -> str:
+    def _get_s3path(self, filepath: str) -> str:
         """Helper to form an AWS S3 URI for the given relative filepath."""
         return f"{self.s3_bucket}/{filepath}"

@@ -141,14 +164,22 @@ class EEGDashBaseDataset(BaseDataset):
            if dep.endswith(".set"):
                dep = dep[:-4] + ".bdf"

-            s3path = self.get_s3path(dep)
+            s3path = self._get_s3path(dep)
            if not self.s3_open_neuro:
                dep = self.bids_dependencies_original[i]

-            filepath = self.cache_dir / dep
+            dep_path = Path(dep)
+            if dep_path.parts and dep_path.parts[0] == self.record.get("dataset"):
+                dep_local = Path(self.dataset_folder, *dep_path.parts[1:])
+            else:
+                dep_local = Path(self.dataset_folder) / dep_path
+            filepath = self.cache_dir / dep_local
            if not self.s3_open_neuro:
+                if filepath.suffix == ".set":
+                    filepath = filepath.with_suffix(".bdf")
                if self.filecache.suffix == ".set":
                    self.filecache = self.filecache.with_suffix(".bdf")
+
            # here, we download the dependency and it is fine
            # in the case of the competition.
            if not filepath.exists():
@@ -174,15 +205,21 @@ class EEGDashBaseDataset(BaseDataset):
            )
            filesystem.get(s3path, filepath, callback=callback)

-    def get_raw_bids_args(self) -> dict[str, Any]:
+    def _get_raw_bids_args(self) -> dict[str, Any]:
        """Helper to restrict the metadata record to the fields needed to locate a BIDS
        recording.
        """
        desired_fields = ["subject", "session", "task", "run"]
        return {k: self.record[k] for k in desired_fields if self.record[k]}

-    def check_and_get_raw(self) -> None:
+    def _ensure_raw(self) -> None:
        """Download the S3 file and BIDS dependencies if not already cached."""
+        # TO-DO: remove this once it is fixed on our side
+        # (workaround for the competition)
+        if not self.s3_open_neuro:
+            self.bidspath = self.bidspath.update(extension=".bdf")
+            self.filecache = self.filecache.with_suffix(".bdf")
+
        if not os.path.exists(self.filecache):  # not preload
            if self.bids_dependencies:
                self._download_dependencies()
@@ -191,14 +228,50 @@ class EEGDashBaseDataset(BaseDataset):
            # capturing any warnings
            # to-do: remove this once it is fixed on the mne-bids side.
            with warnings.catch_warnings(record=True) as w:
+                # Ensure all warnings are captured into 'w' and not shown to users
+                warnings.simplefilter("always")
                try:
-                    # TO-DO: remove this once is fixed on the our side
-                    if not self.s3_open_neuro:
-                        self.bidspath = self.bidspath.update(extension=".bdf")
-
-                    self._raw = mne_bids.read_raw_bids(
-                        bids_path=self.bidspath, verbose="ERROR"
-                    )
+                    # mne-bids emits RuntimeWarnings to stderr; silence stderr during read
+                    _stderr_buffer = io.StringIO()
+                    with redirect_stderr(_stderr_buffer):
+                        self._raw = mne_bids.read_raw_bids(
+                            bids_path=self.bidspath, verbose="ERROR"
+                        )
+                    # Parse unmapped participants.tsv fields reported by mne-bids and
+                    # inject them into Raw.info and the dataset description generically.
+                    extras = self._extract_unmapped_participants_from_warnings(w)
+                    if extras:
+                        # 1) Attach to Raw.info under subject_info.participants_extras
+                        try:
+                            subject_info = self._raw.info.get("subject_info") or {}
+                            if not isinstance(subject_info, dict):
+                                subject_info = {}
+                            pe = subject_info.get("participants_extras") or {}
+                            if not isinstance(pe, dict):
+                                pe = {}
+                            # Merge without overwriting
+                            for k, v in extras.items():
+                                pe.setdefault(k, v)
+                            subject_info["participants_extras"] = pe
+                            self._raw.info["subject_info"] = subject_info
+                        except Exception:
+                            # Non-fatal; continue
+                            pass
+
+                        # 2) Also add to this dataset's description, if possible, so
+                        # targets can be selected later without naming specifics.
+                        try:
+                            import pandas as _pd  # local import to avoid top-level cost
+
+                            if isinstance(self.description, dict):
+                                for k, v in extras.items():
+                                    self.description.setdefault(k, v)
+                            elif isinstance(self.description, _pd.Series):
+                                for k, v in extras.items():
+                                    if k not in self.description.index:
+                                        self.description.loc[k] = v
+                        except Exception:
+                            pass
                except Exception as e:
                    logger.error(
                        f"Error while reading BIDS file: {self.bidspath}\n"
@@ -208,10 +281,60 @@ class EEGDashBaseDataset(BaseDataset):
                    logger.error(f"Exception: {e}")
                    logger.error(traceback.format_exc())
                    raise e
-            for warning in w:
-                logger.warning(
-                    f"Warning while reading BIDS file: {warning.message}"
-                )
+            # Filter noisy mapping notices from mne-bids; surface others
+            for captured_warning in w:
+                try:
+                    msg = str(captured_warning.message)
+                except Exception:
+                    continue
+                # Suppress verbose participants mapping messages
+                if "Unable to map the following column" in msg and "MNE" in msg:
+                    logger.debug(
+                        "Suppressed mne-bids mapping warning while reading BIDS file: %s",
+                        msg,
+                    )
+                    continue
+                logger.warning("Warning while reading BIDS file: %s", msg)
+
+    def _extract_unmapped_participants_from_warnings(
+        self, warnings_list: list[Any]
+    ) -> dict[str, Any]:
+        """Scan captured warnings from mne-bids and extract unmapped participants.tsv
+        entries in a generic way.
+
+        The column name may carry a parenthesized note, which is ignored during
+        key/value extraction. Returns a mapping of column name -> raw value.
+        """
+        extras: dict[str, Any] = {}
+        header = "Unable to map the following column(s) to MNE:"
+        for wr in warnings_list:
+            try:
+                msg = str(wr.message)
+            except Exception:
+                continue
+            if header not in msg:
+                continue
+            lines = msg.splitlines()
+            # Find the header line, then parse subsequent lines as entries
+            try:
+                idx = next(i for i, ln in enumerate(lines) if header in ln)
+            except StopIteration:
+                idx = -1
+            for line in lines[idx + 1 :]:
+                line = line.strip()
+                if not line:
+                    continue
+                # Pattern: <col>(optional note): <value>
+                # Examples: "gender: F", "Ethnicity: Indian", "foo (ignored): bar"
+                m = re.match(r"^([^:]+?)(?:\s*\([^)]*\))?\s*:\s*(.*)$", line)
+                if not m:
+                    continue
+                col = m.group(1).strip()
+                val = m.group(2).strip()
+                # Keep original column names as provided to stay agnostic
+                if col and col not in extras:
+                    extras[col] = val
+        return extras

    # === BaseDataset and PyTorch Dataset interface ===

@@ -230,11 +353,16 @@ class EEGDashBaseDataset(BaseDataset):
    def __len__(self) -> int:
        """Return the number of samples in the dataset."""
        if self._raw is None:
-            # FIXME: this is a bit strange and should definitely not change as a side effect
-            # of accessing the data (which it will, since ntimes is the actual length but rounded down)
-            return int(self.record["ntimes"] * self.record["sampling_frequency"])
-        else:
-            return len(self._raw)
+            if (
+                self.record["ntimes"] is None
+                or self.record["sampling_frequency"] is None
+            ):
+                self._ensure_raw()
+            else:
+                # FIXME: this is a bit strange and should definitely not change as a side effect
+                # of accessing the data (which it will, since ntimes is the actual length but rounded down)
+                return int(self.record["ntimes"] * self.record["sampling_frequency"])
+        return len(self._raw)

    @property
    def raw(self):
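The metadata branch above estimates the length without loading any data; a quick check with made-up record values (treating ntimes as a duration, which the multiplication implies):

    record = {"ntimes": 12.5, "sampling_frequency": 100.0}  # hypothetical values
    if record["ntimes"] is not None and record["sampling_frequency"] is not None:
        print(int(record["ntimes"] * record["sampling_frequency"]))  # 1250; int() rounds down
    # With either field missing, the new code loads the raw file and uses len(raw).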
@@ -242,7 +370,7 @@ class EEGDashBaseDataset(BaseDataset):
        retrieval if not yet done so.
        """
        if self._raw is None:
-            self.check_and_get_raw()
+            self._ensure_raw()
        return self._raw

    @raw.setter
@@ -284,7 +412,7 @@ class EEGDashBaseRaw(BaseRaw):
        metadata: dict[str, Any],
        preload: bool = False,
        *,
-        cache_dir: str = "~/eegdash_cache",
+        cache_dir: str | None = None,
        bids_dependencies: list[str] = [],
        verbose: Any = None,
    ):
@@ -300,8 +428,9 @@ class EEGDashBaseRaw(BaseRaw):
                chtype = "eog"
            ch_types.append(chtype)
        info = mne.create_info(ch_names=ch_names, sfreq=sfreq, ch_types=ch_types)
-        self.s3file = self.get_s3path(input_fname)
-        self.cache_dir = Path(cache_dir)
+
+        self.s3file = self._get_s3path(input_fname)
+        self.cache_dir = Path(cache_dir) if cache_dir else get_default_cache_dir()
        self.filecache = self.cache_dir / input_fname
        self.bids_dependencies = bids_dependencies

@@ -317,7 +446,7 @@ class EEGDashBaseRaw(BaseRaw):
            verbose=verbose,
        )

-    def get_s3path(self, filepath):
+    def _get_s3path(self, filepath):
        return f"{self._AWS_BUCKET}/{filepath}"

    def _download_s3(self) -> None:
@@ -333,7 +462,7 @@ class EEGDashBaseRaw(BaseRaw):
            anon=True, client_kwargs={"region_name": "us-east-2"}
        )
        for dep in self.bids_dependencies:
-            s3path = self.get_s3path(dep)
+            s3path = self._get_s3path(dep)
            filepath = self.cache_dir / dep
            if not filepath.exists():
                filepath.parent.mkdir(parents=True, exist_ok=True)
@@ -394,11 +523,17 @@ class EEGBIDSDataset:
            raise ValueError("data_dir must be specified and must exist")
        self.bidsdir = Path(data_dir)
        self.dataset = dataset
-        assert str(self.bidsdir).endswith(self.dataset)
+        # Accept exact dataset folder or a variant with informative suffixes
+        # (e.g., dsXXXXX-bdf, dsXXXXX-bdf-mini) to avoid collisions.
+        dir_name = self.bidsdir.name
+        if not (dir_name == self.dataset or dir_name.startswith(self.dataset + "-")):
+            raise AssertionError(
+                f"BIDS directory '{dir_name}' does not correspond to dataset '{self.dataset}'"
+            )
        self.layout = BIDSLayout(data_dir)

        # get all recording files in the bids directory
-        self.files = self.get_recordings(self.layout)
+        self.files = self._get_recordings(self.layout)
        assert len(self.files) > 0, ValueError(
            "Unable to construct EEG dataset. No EEG recordings found."
        )
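The replacement check above is worth tabulating, since it both loosens and tightens the old assert (dataset id and folder names are hypothetical):

    dataset = "ds005505"
    for dir_name in ("ds005505", "ds005505-bdf", "ds005505-bdf-mini", "ds005505x", "xds005505"):
        accepted = dir_name == dataset or dir_name.startswith(dataset + "-")
        print(dir_name, accepted)
    # ds005505 True, ds005505-bdf True, ds005505-bdf-mini True,
    # ds005505x False, xds005505 False

The old assert str(self.bidsdir).endswith(self.dataset) rejected suffixed cache folders like ds005505-bdf and accepted unrelated names such as xds005505; the explicit check fixes both cases.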
@@ -408,7 +543,7 @@ class EEGBIDSDataset:
        """Check if the dataset is EEG."""
        return self.get_bids_file_attribute("modality", self.files[0]).lower() == "eeg"

-    def get_recordings(self, layout: BIDSLayout) -> list[str]:
+    def _get_recordings(self, layout: BIDSLayout) -> list[str]:
        """Get a list of all EEG recording files in the BIDS layout."""
        files = []
        for ext, exts in self.RAW_EXTENSIONS.items():
@@ -417,12 +552,12 @@ class EEGBIDSDataset:
                break
        return files

-    def get_relative_bidspath(self, filename: str) -> str:
+    def _get_relative_bidspath(self, filename: str) -> str:
        """Make the given file path relative to the BIDS directory."""
        bids_parent_dir = self.bidsdir.parent.absolute()
        return str(Path(filename).relative_to(bids_parent_dir))

-    def get_property_from_filename(self, property: str, filename: str) -> str:
+    def _get_property_from_filename(self, property: str, filename: str) -> str:
        """Parse a property out of a BIDS-compliant filename. Returns an empty string
        if not found.
        """
@@ -434,7 +569,7 @@ class EEGBIDSDataset:
        lookup = re.search(rf"{property}-(.*?)[_\/]", filename)
        return lookup.group(1) if lookup else ""

-    def merge_json_inheritance(self, json_files: list[str | Path]) -> dict:
+    def _merge_json_inheritance(self, json_files: list[str | Path]) -> dict:
        """Internal helper to merge the list of json files found by get_bids_file_inheritance,
        expecting the order (from left to right) to run from lowest
        level to highest level, and returning a merged dictionary
@@ -445,7 +580,7 @@ class EEGBIDSDataset:
            json_dict.update(json.load(open(f)))  # FIXME: should close file
        return json_dict

-    def get_bids_file_inheritance(
+    def _get_bids_file_inheritance(
        self, path: str | Path, basename: str, extension: str
    ) -> list[Path]:
        """Get all file paths that apply to the basename file in the specified directory
@@ -492,7 +627,7 @@ class EEGBIDSDataset:
        else:
            # call get_bids_file_inheritance recursively with parent directory
            bids_files.extend(
-                self.get_bids_file_inheritance(path.parent, basename, extension)
+                self._get_bids_file_inheritance(path.parent, basename, extension)
            )
        return bids_files

@@ -523,12 +658,12 @@ class EEGBIDSDataset:
        path, filename = os.path.split(filepath)
        basename = filename[: filename.rfind("_")]
        # metadata files
-        meta_files = self.get_bids_file_inheritance(
+        meta_files = self._get_bids_file_inheritance(
            path, basename, metadata_file_extension
        )
        return meta_files

-    def scan_directory(self, directory: str, extension: str) -> list[Path]:
+    def _scan_directory(self, directory: str, extension: str) -> list[Path]:
        """Return a list of file paths that end with the given extension in the specified
        directory. Ignores certain special directories like .git, .datalad, derivatives,
        and code.
@@ -545,7 +680,7 @@ class EEGBIDSDataset:
                result_files.append(entry.path)  # Add directory to scan later
        return result_files

-    def get_files_with_extension_parallel(
+    def _get_files_with_extension_parallel(
        self, directory: str, extension: str = ".set", max_workers: int = -1
    ) -> list[Path]:
        """Efficiently scan a directory and its subdirectories for files that end with
@@ -577,7 +712,7 @@ class EEGBIDSDataset:
        )
        # Run the scan_directory function in parallel across directories
        results = Parallel(n_jobs=max_workers, prefer="threads", verbose=1)(
-            delayed(self.scan_directory)(d, extension) for d in dirs_to_scan
+            delayed(self._scan_directory)(d, extension) for d in dirs_to_scan
        )

        # Reset the directories to scan and process the results
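The fan-out above relies on joblib's threading backend, a sensible choice for I/O-bound directory scanning since it avoids pickling the worker arguments. A reduced sketch with a stub scanner (directory names are hypothetical):

    from joblib import Parallel, delayed

    def scan_directory(directory, extension):
        # Stub: a real implementation would os.scandir() and filter by extension.
        return [f"{directory}/recording{extension}"]

    dirs_to_scan = ["ds000001/sub-01", "ds000001/sub-02"]
    results = Parallel(n_jobs=-1, prefer="threads", verbose=0)(
        delayed(scan_directory)(d, ".set") for d in dirs_to_scan
    )
    files = [f for per_dir in results for f in per_dir]  # flatten per-directory lists
    print(files)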
@@ -682,7 +817,7 @@ class EEGBIDSDataset:
    def num_times(self, data_filepath: str) -> int:
        """Get the approximate number of time points in the EEG recording based on the BIDS metadata."""
        eeg_jsons = self.get_bids_metadata_files(data_filepath, "eeg.json")
-        eeg_json_dict = self.merge_json_inheritance(eeg_jsons)
+        eeg_json_dict = self._merge_json_inheritance(eeg_jsons)
        return int(
            eeg_json_dict["SamplingFrequency"] * eeg_json_dict["RecordingDuration"]
        )
@@ -705,7 +840,7 @@ class EEGBIDSDataset:
    def eeg_json(self, data_filepath: str) -> dict[str, Any]:
        """Get BIDS eeg.json metadata for the given data file path."""
        eeg_jsons = self.get_bids_metadata_files(data_filepath, "eeg.json")
-        eeg_json_dict = self.merge_json_inheritance(eeg_jsons)
+        eeg_json_dict = self._merge_json_inheritance(eeg_jsons)
        return eeg_json_dict

    def channel_tsv(self, data_filepath: str) -> dict[str, Any]:
(new file, package __init__ for the added dataset module) ADDED
@@ -0,0 +1,4 @@
+from .dataset import EEGChallengeDataset
+from .registry import register_openneuro_datasets
+
+__all__ = ["EEGChallengeDataset", "register_openneuro_datasets"]
(new file, dataset module defining EEGChallengeDataset) ADDED
@@ -0,0 +1,161 @@
+import logging
+from pathlib import Path
+
+from mne.utils import warn
+
+from ..api import EEGDashDataset
+from ..bids_eeg_metadata import build_query_from_kwargs
+from ..const import RELEASE_TO_OPENNEURO_DATASET_MAP, SUBJECT_MINI_RELEASE_MAP
+from .registry import register_openneuro_datasets
+
+logger = logging.getLogger("eegdash")
+
+
+class EEGChallengeDataset(EEGDashDataset):
+    def __init__(
+        self,
+        release: str,
+        cache_dir: str,
+        mini: bool = True,
+        query: dict | None = None,
+        s3_bucket: str | None = "s3://nmdatasets/NeurIPS25",
+        **kwargs,
+    ):
+        """Create a new EEGDashDataset from a given query or local BIDS dataset directory
+        and dataset name. An EEGDashDataset is a pooled collection of EEGDashBaseDataset
+        instances (individual recordings) and is a subclass of braindecode's BaseConcatDataset.
+
+        Parameters
+        ----------
+        release : str
+            Release name. Can be one of ["R1", ..., "R11"].
+        mini : bool, default True
+            Whether to use the mini-release version of the dataset. The mini version
+            is recommended for faster training and evaluation.
+        query : dict | None
+            Optionally, a dictionary that specifies a query to be executed
+            in addition to the dataset (automatically inferred from the release argument).
+            See EEGDash.find() for details on the query format.
+        cache_dir : str
+            A directory where the dataset will be cached locally.
+        s3_bucket : str | None
+            An optional S3 bucket URI to use instead of the
+            default OpenNeuro bucket for loading data files.
+        kwargs : dict
+            Additional keyword arguments to be passed to the EEGDashDataset
+            constructor.
+
+        """
+        self.release = release
+        self.mini = mini
+
+        if release not in RELEASE_TO_OPENNEURO_DATASET_MAP:
+            raise ValueError(
+                f"Unknown release: {release}, expected one of {list(RELEASE_TO_OPENNEURO_DATASET_MAP.keys())}"
+            )
+
+        dataset_parameters = []
+        if isinstance(release, str):
+            dataset_parameters.append(RELEASE_TO_OPENNEURO_DATASET_MAP[release])
+        else:
+            raise ValueError(
+                f"Unknown release type: {type(release)}, the expected type is str."
+            )
+
+        if query and "dataset" in query:
+            raise ValueError(
+                "Querying on the `dataset` parameter is not possible with EEGChallengeDataset. "
+                "Please use the release argument, or use EEGDashDataset directly instead."
+            )
+
+        if self.mini:
+            # When using the mini release, restrict subjects to the predefined subset.
+            # If the user specifies subject(s), ensure they all belong to the mini subset;
+            # otherwise, default to the full mini subject list for this release.
+
+            allowed_subjects = set(SUBJECT_MINI_RELEASE_MAP[release])
+
+            # Normalize potential 'subjects' -> 'subject' for convenience
+            if "subjects" in kwargs and "subject" not in kwargs:
+                kwargs["subject"] = kwargs.pop("subjects")
+
+            # Collect user-requested subjects from kwargs/query. We canonicalize
+            # kwargs via build_query_from_kwargs to leverage existing validation,
+            # and support Mongo-style {"$in": [...]} shapes from a raw query.
+            requested_subjects: list[str] = []
+
+            # From kwargs
+            if "subject" in kwargs and kwargs["subject"] is not None:
+                # Use the shared query builder to normalize scalars/lists
+                built = build_query_from_kwargs(subject=kwargs["subject"])
+                s_val = built.get("subject")
+                if isinstance(s_val, dict) and "$in" in s_val:
+                    requested_subjects.extend(list(s_val["$in"]))
+                elif s_val is not None:
+                    requested_subjects.append(s_val)  # type: ignore[arg-type]
+
+            # From query (top-level only)
+            if query and isinstance(query, dict) and "subject" in query:
+                qval = query["subject"]
+                if isinstance(qval, dict) and "$in" in qval:
+                    requested_subjects.extend(list(qval["$in"]))
+                elif isinstance(qval, (list, tuple, set)):
+                    requested_subjects.extend(list(qval))
+                elif qval is not None:
+                    requested_subjects.append(qval)
+
+            # Validate if any subjects were explicitly requested
+            if requested_subjects:
+                invalid = sorted(
+                    {s for s in requested_subjects if s not in allowed_subjects}
+                )
+                if invalid:
+                    raise ValueError(
+                        "Some requested subject(s) are not part of the mini release for "
+                        f"{release}: {invalid}. Allowed subjects: {sorted(allowed_subjects)}"
+                    )
+                # Do not override user selection; keep their (validated) subjects as-is.
+            else:
+                # No subject specified by the user: default to the full mini subset
+                kwargs["subject"] = sorted(allowed_subjects)
+
+            s3_bucket = f"{s3_bucket}/{release}_mini_L100_bdf"
+        else:
+            s3_bucket = f"{s3_bucket}/{release}_L100_bdf"
+
+        warn(
+            "\n\n"
+            "[EEGChallengeDataset] EEG 2025 Competition Data Notice:\n"
+            "-------------------------------------------------------\n"
+            "This object loads the HBN dataset that has been preprocessed for the EEG Challenge:\n"
+            "  - Downsampled from 500 Hz to 100 Hz\n"
+            "  - Bandpass filtered (0.5–50 Hz)\n"
+            "\n"
+            "For full preprocessing details, see:\n"
+            "    https://github.com/eeg2025/downsample-datasets\n"
+            "\n"
+            "IMPORTANT: The data accessed via `EEGChallengeDataset` is NOT identical to what you get from `EEGDashDataset` directly.\n"
+            "If you are participating in the competition, always use `EEGChallengeDataset` to ensure consistency with the challenge data.\n"
+            "\n",
+            UserWarning,
+            module="eegdash",
+        )
+
+        super().__init__(
+            dataset=RELEASE_TO_OPENNEURO_DATASET_MAP[release],
+            query=query,
+            cache_dir=cache_dir,
+            s3_bucket=s3_bucket,
+            _suppress_comp_warning=True,
+            **kwargs,
+        )
+
+
+registered_classes = register_openneuro_datasets(
+    summary_file=Path(__file__).with_name("dataset_summary.csv"),
+    base_class=EEGDashDataset,
+    namespace=globals(),
+)
+
+
+__all__ = ["EEGChallengeDataset"] + list(registered_classes.keys())
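Putting the new class to work, a minimal usage sketch (hedged: the top-level re-export and the argument values below are assumptions, not shown in this diff):

    from eegdash import EEGChallengeDataset  # assumed re-export; otherwise import from the new module

    # Mini release: subjects are restricted to the predefined subset for "R5",
    # and files come from the preprocessed bucket .../R5_mini_L100_bdf.
    ds = EEGChallengeDataset(release="R5", cache_dir="./eegdash_cache", mini=True)
    print(len(ds.datasets))  # number of pooled recordings

    # Passing dataset=... in a query raises ValueError (use `release` instead),
    # and subject filters outside the mini subset raise ValueError as well.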