eegdash 0.3.6.dev182011805__py3-none-any.whl → 0.3.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


@@ -0,0 +1,254 @@
+ import logging
+ import re
+ from pathlib import Path
+ from typing import Any
+
+ from .const import ALLOWED_QUERY_FIELDS
+ from .const import config as data_config
+
+ logger = logging.getLogger("eegdash")
+
+ __all__ = [
+     "build_query_from_kwargs",
+     "load_eeg_attrs_from_bids_file",
+     "merge_participants_fields",
+     "normalize_key",
+ ]
+
+
+ def build_query_from_kwargs(**kwargs) -> dict[str, Any]:
+     """Build and validate a MongoDB query from user-friendly keyword arguments.
+
+     Improvements:
+     - Reject None values and empty/whitespace-only strings
+     - For list/tuple/set values: strip strings, drop None/empties, deduplicate, and use `$in`
+     - Preserve scalars as exact matches
+     """
+     # 1. Validate that all provided keys are allowed for querying
+     unknown_fields = set(kwargs.keys()) - ALLOWED_QUERY_FIELDS
+     if unknown_fields:
+         raise ValueError(
+             f"Unsupported query field(s): {', '.join(sorted(unknown_fields))}. "
+             f"Allowed fields are: {', '.join(sorted(ALLOWED_QUERY_FIELDS))}"
+         )
+
+     # 2. Construct the query dictionary
+     query = {}
+     for key, value in kwargs.items():
+         # None is not a valid constraint
+         if value is None:
+             raise ValueError(
+                 f"Received None for query parameter '{key}'. Provide a concrete value."
+             )
+
+         # Handle list-like values as multi-constraints
+         if isinstance(value, (list, tuple, set)):
+             cleaned: list[Any] = []
+             for item in value:
+                 if item is None:
+                     continue
+                 if isinstance(item, str):
+                     item = item.strip()
+                     if not item:
+                         continue
+                 cleaned.append(item)
+             # Deduplicate while preserving order
+             cleaned = list(dict.fromkeys(cleaned))
+             if not cleaned:
+                 raise ValueError(
+                     f"Received an empty list for query parameter '{key}'. This is not supported."
+                 )
+             query[key] = {"$in": cleaned}
+         else:
+             # Scalars: trim strings and validate
+             if isinstance(value, str):
+                 value = value.strip()
+                 if not value:
+                     raise ValueError(
+                         f"Received an empty string for query parameter '{key}'."
+                     )
+             query[key] = value
+
+     return query
+
+
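A quick illustration of the validation rules above (the import path is an assumption; the diff does not name the new module):

    from eegdash.utils import build_query_from_kwargs  # hypothetical import path

    # List values are cleaned (strings stripped, None/empties dropped),
    # deduplicated, and expressed as a MongoDB $in constraint:
    build_query_from_kwargs(dataset="ds005516", subject=["01", " 02 ", "01", None])
    # -> {'dataset': 'ds005516', 'subject': {'$in': ['01', '02']}}

    # Unknown fields fail fast:
    build_query_from_kwargs(foo="bar")
    # ValueError: Unsupported query field(s): foo. Allowed fields are: ...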
+ def _get_raw_extensions(bids_file: str, bids_dataset) -> list[str]:
+     """Helper to find paths to additional "sidecar" files that may be associated
+     with a given main data file in a BIDS dataset; paths are returned as relative to
+     the parent dataset path.
+
+     For example, if the input file is a .set file, this will return the relative path
+     to a corresponding .fdt file (if any).
+     """
+     bids_file = Path(bids_file)
+     extensions = {
+         ".set": [".set", ".fdt"],  # eeglab
+         ".edf": [".edf"],  # european
+         ".vhdr": [".eeg", ".vhdr", ".vmrk", ".dat", ".raw"],  # brainvision
+         ".bdf": [".bdf"],  # biosemi
+     }
+     return [
+         str(bids_dataset._get_relative_bidspath(bids_file.with_suffix(suffix)))
+         for suffix in extensions[bids_file.suffix]
+         if bids_file.with_suffix(suffix).exists()
+     ]
+
+
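To see the helper's behavior without a real dataset object, one can stub the single method it calls; everything below (the stub class, the temp layout, the file names) is illustrative only:

    import tempfile
    from pathlib import Path

    class _StubDataset:
        """Stand-in exposing only the _get_relative_bidspath method used above."""

        def __init__(self, root: Path):
            self.root = root

        def _get_relative_bidspath(self, path: Path) -> Path:
            return path.relative_to(self.root)

    root = Path(tempfile.mkdtemp())
    eeg_dir = root / "sub-01" / "eeg"
    eeg_dir.mkdir(parents=True)
    (eeg_dir / "sub-01_task-rest_eeg.set").touch()
    (eeg_dir / "sub-01_task-rest_eeg.fdt").touch()

    _get_raw_extensions(str(eeg_dir / "sub-01_task-rest_eeg.set"), _StubDataset(root))
    # -> ['sub-01/eeg/sub-01_task-rest_eeg.set', 'sub-01/eeg/sub-01_task-rest_eeg.fdt']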
+ def load_eeg_attrs_from_bids_file(bids_dataset, bids_file: str) -> dict[str, Any]:
+     """Build the metadata record for a given BIDS file (single recording) in a BIDS dataset.
+
+     Attributes are at least the ones defined in data_config attributes (set to None if missing),
+     but are typically a superset, and include, among others, the paths to relevant
+     meta-data files needed to load and interpret the file in question.
+
+     Parameters
+     ----------
+     bids_dataset : EEGBIDSDataset
+         The BIDS dataset object containing the file.
+     bids_file : str
+         The path to the BIDS file within the dataset.
+
+     Returns
+     -------
+     dict:
+         A dictionary representing the metadata record for the given file. This is the
+         same format as the records stored in the database.
+
+     """
+     if bids_file not in bids_dataset.files:
+         raise ValueError(f"{bids_file} not in {bids_dataset.dataset}")
+
+     # Initialize attrs with None values for all expected fields
+     attrs = {field: None for field in data_config["attributes"].keys()}
+
+     file = Path(bids_file).name
+     dsnumber = bids_dataset.dataset
+     # extract openneuro path by finding the first occurrence of the dataset name in the filename and remove the path before that
+     openneuro_path = dsnumber + bids_file.split(dsnumber)[1]
+
+     # Update with actual values where available
+     try:
+         participants_tsv = bids_dataset.subject_participant_tsv(bids_file)
+     except Exception as e:
+         logger.error("Error getting participants_tsv: %s", str(e))
+         participants_tsv = None
+
+     try:
+         eeg_json = bids_dataset.eeg_json(bids_file)
+     except Exception as e:
+         logger.error("Error getting eeg_json: %s", str(e))
+         eeg_json = None
+
+     bids_dependencies_files = data_config["bids_dependencies_files"]
+     bidsdependencies = []
+     for extension in bids_dependencies_files:
+         try:
+             dep_path = bids_dataset.get_bids_metadata_files(bids_file, extension)
+             dep_path = [
+                 str(bids_dataset.get_relative_bidspath(dep)) for dep in dep_path
+             ]
+             bidsdependencies.extend(dep_path)
+         except Exception:
+             pass
+
+     bidsdependencies.extend(_get_raw_extensions(bids_file, bids_dataset))
+
+     # Define field extraction functions with error handling
+     field_extractors = {
+         "data_name": lambda: f"{bids_dataset.dataset}_{file}",
+         "dataset": lambda: bids_dataset.dataset,
+         "bidspath": lambda: openneuro_path,
+         "subject": lambda: bids_dataset.get_bids_file_attribute("subject", bids_file),
+         "task": lambda: bids_dataset.get_bids_file_attribute("task", bids_file),
+         "session": lambda: bids_dataset.get_bids_file_attribute("session", bids_file),
+         "run": lambda: bids_dataset.get_bids_file_attribute("run", bids_file),
+         "modality": lambda: bids_dataset.get_bids_file_attribute("modality", bids_file),
+         "sampling_frequency": lambda: bids_dataset.get_bids_file_attribute(
+             "sfreq", bids_file
+         ),
+         "nchans": lambda: bids_dataset.get_bids_file_attribute("nchans", bids_file),
+         "ntimes": lambda: bids_dataset.get_bids_file_attribute("ntimes", bids_file),
+         "participant_tsv": lambda: participants_tsv,
+         "eeg_json": lambda: eeg_json,
+         "bidsdependencies": lambda: bidsdependencies,
+     }
+
+     # Dynamically populate attrs with error handling
+     for field, extractor in field_extractors.items():
+         try:
+             attrs[field] = extractor()
+         except Exception as e:
+             logger.error("Error extracting %s : %s", field, str(e))
+             attrs[field] = None
+
+     return attrs
+
+
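The `field_extractors` table implements a defensive-extraction pattern: each attribute is computed by its own callable, so a failure in one extractor degrades only that field to None instead of aborting the whole record. A self-contained sketch of the same idea:

    import logging

    logger = logging.getLogger("eegdash")

    record = {}
    extractors = {
        "nchans": lambda: int("64"),    # succeeds
        "task": lambda: {}["missing"],  # deliberately raises KeyError
    }
    for field, extract in extractors.items():
        try:
            record[field] = extract()
        except Exception as e:
            logger.error("Error extracting %s : %s", field, e)
            record[field] = None

    # record == {'nchans': 64, 'task': None}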
+ def normalize_key(key: str) -> str:
+     """Normalize a metadata key for robust matching.
+
+     Lowercase and replace non-alphanumeric characters with underscores, then strip
+     leading/trailing underscores. This allows tolerant matching such as
+     "p-factor" ≈ "p_factor" ≈ "P Factor".
+     """
+     return re.sub(r"[^a-z0-9]+", "_", str(key).lower()).strip("_")
+
+
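Doctest-style, the docstring's examples all collapse to the same key:

    assert normalize_key("p-factor") == "p_factor"
    assert normalize_key("P Factor") == "p_factor"
    assert normalize_key(" p.factor! ") == "p_factor"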
+ def merge_participants_fields(
+     description: dict[str, Any],
+     participants_row: dict[str, Any] | None,
+     description_fields: list[str] | None = None,
+ ) -> dict[str, Any]:
+     """Merge participants.tsv fields into a dataset description dictionary.
+
+     - Preserves existing entries in ``description`` (no overwrites).
+     - Fills requested ``description_fields`` first, preserving their original names.
+     - Adds all remaining participants columns generically using normalized keys
+       unless a matching requested field already captured them.
+
+     Parameters
+     ----------
+     description : dict
+         Current description to be enriched in-place and returned.
+     participants_row : dict | None
+         A mapping of participants.tsv columns for the current subject.
+     description_fields : list[str] | None
+         Optional list of requested description fields. When provided, matching is
+         performed by normalized names; the original requested field names are kept.
+
+     Returns
+     -------
+     dict
+         The enriched description (same object as input for convenience).
+
+     """
+     if not isinstance(description, dict) or not isinstance(participants_row, dict):
+         return description
+
+     # Normalize participants keys and keep first non-None value per normalized key
+     norm_map: dict[str, Any] = {}
+     for part_key, part_value in participants_row.items():
+         norm_key = normalize_key(part_key)
+         if norm_key not in norm_map and part_value is not None:
+             norm_map[norm_key] = part_value
+
+     # Ensure description_fields is a list for matching
+     requested = list(description_fields or [])
+
+     # 1) Fill requested fields first using normalized matching, preserving names
+     for key in requested:
+         if key in description:
+             continue
+         requested_norm_key = normalize_key(key)
+         if requested_norm_key in norm_map:
+             description[key] = norm_map[requested_norm_key]
+
+     # 2) Add remaining participants columns generically under normalized names,
+     #    unless a requested field already captured them
+     requested_norm = {normalize_key(k) for k in requested}
+     for norm_key, part_value in norm_map.items():
+         if norm_key in requested_norm:
+             continue
+         if norm_key not in description:
+             description[norm_key] = part_value
+     return description
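Taken together, the two passes behave as follows (a worked example with illustrative values):

    description = {"subject": "01"}
    row = {"Age": 23, "p-factor": 1.9, "subject": "sub-01"}
    merge_participants_fields(description, row, description_fields=["Age"])
    # -> {'subject': '01', 'Age': 23, 'p_factor': 1.9}
    # 'subject' is never overwritten, the requested 'Age' keeps its original
    # spelling, and 'p-factor' is added generically under its normalized name.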
@@ -1,7 +1,15 @@
- from pathlib import Path
-
- from .api import EEGDashDataset
- from .registry import register_openneuro_datasets
+ ALLOWED_QUERY_FIELDS = {
+     "data_name",
+     "dataset",
+     "subject",
+     "task",
+     "session",
+     "run",
+     "modality",
+     "sampling_frequency",
+     "nchans",
+     "ntimes",
+ }
 
  RELEASE_TO_OPENNEURO_DATASET_MAP = {
      "R11": "ds005516",
@@ -262,92 +270,37 @@ SUBJECT_MINI_RELEASE_MAP = {
      ],
  }
 
-
- class EEGChallengeDataset(EEGDashDataset):
-     def __init__(
-         self,
-         release: str,
-         cache_dir: str,
-         mini: bool = True,
-         query: dict | None = None,
-         s3_bucket: str | None = "s3://nmdatasets/NeurIPS25",
-         **kwargs,
-     ):
-         """Create a new EEGDashDataset from a given query or local BIDS dataset directory
-         and dataset name. An EEGDashDataset is pooled collection of EEGDashBaseDataset
-         instances (individual recordings) and is a subclass of braindecode's BaseConcatDataset.
-
-         Parameters
-         ----------
-         release: str
-             Release name. Can be one of ["R1", ..., "R11"]
-         mini: bool, default True
-             Whether to use the mini-release version of the dataset. It is recommended
-             to use the mini version for faster training and evaluation.
-         query : dict | None
-             Optionally a dictionary that specifies a query to be executed,
-             in addition to the dataset (automatically inferred from the release argument).
-             See EEGDash.find() for details on the query format.
-         cache_dir : str
-             A directory where the dataset will be cached locally.
-         s3_bucket : str | None
-             An optional S3 bucket URI to use instead of the
-             default OpenNeuro bucket for loading data files.
-         kwargs : dict
-             Additional keyword arguments to be passed to the EEGDashDataset
-             constructor.
-
-         """
-         self.release = release
-         self.mini = mini
-
-         if release not in RELEASE_TO_OPENNEURO_DATASET_MAP:
-             raise ValueError(
-                 f"Unknown release: {release}, expected one of {list(RELEASE_TO_OPENNEURO_DATASET_MAP.keys())}"
-             )
-
-         dataset_parameters = []
-         if isinstance(release, str):
-             dataset_parameters.append(RELEASE_TO_OPENNEURO_DATASET_MAP[release])
-         else:
-             raise ValueError(
-                 f"Unknown release type: {type(release)}, the expected type is str."
-             )
-
-         if query and "dataset" in query:
-             raise ValueError(
-                 "Query using the parameters `dataset` with the class EEGChallengeDataset is not possible."
-                 "Please use the release argument instead, or the object EEGDashDataset instead."
-             )
-
-         if self.mini:
-             # Disallow mixing subject selection with mini=True since mini already
-             # applies a predefined subject subset.
-             if (query and "subject" in query) or ("subject" in kwargs):
-                 raise ValueError(
-                     "Query using the parameters `subject` with the class EEGChallengeDataset and `mini==True` is not possible."
-                     "Please don't use the `subject` selection twice."
-                     "Set `mini=False` to use the `subject` selection."
-                 )
-             kwargs["subject"] = SUBJECT_MINI_RELEASE_MAP[release]
-             s3_bucket = f"{s3_bucket}/{release}_mini_L100_bdf"
-         else:
-             s3_bucket = f"{s3_bucket}/{release}_L100_bdf"
-
-         super().__init__(
-             dataset=dataset_parameters,
-             query=query,
-             cache_dir=cache_dir,
-             s3_bucket=s3_bucket,
-             **kwargs,
-         )
-
-
- registered_classes = register_openneuro_datasets(
-     summary_file=Path(__file__).with_name("dataset_summary.csv"),
-     base_class=EEGDashDataset,
-     namespace=globals(),
- )
-
-
- __all__ = ["EEGChallengeDataset"] + list(registered_classes.keys())
+ config = {
+     "required_fields": ["data_name"],
+     # Default set of user-facing primary record attributes expected in the database. Records
+     # where any of these are missing will be loaded with the respective attribute set to None.
+     # Additional fields may be returned if they are present in the database, notably bidsdependencies.
+     "attributes": {
+         "data_name": "str",
+         "dataset": "str",
+         "bidspath": "str",
+         "subject": "str",
+         "task": "str",
+         "session": "str",
+         "run": "str",
+         "sampling_frequency": "float",
+         "modality": "str",
+         "nchans": "int",
+         "ntimes": "int",  # note: this is really the number of seconds in the data, rounded down
+     },
+     # queryable descriptive fields for a given recording
+     "description_fields": ["subject", "session", "run", "task", "age", "gender", "sex"],
+     # list of filenames that may be present in the BIDS dataset directory that are used
+     # to load and interpret a given BIDS recording.
+     "bids_dependencies_files": [
+         "dataset_description.json",
+         "participants.tsv",
+         "events.tsv",
+         "events.json",
+         "eeg.json",
+         "electrodes.tsv",
+         "channels.tsv",
+         "coordsystem.json",
+     ],
+     "accepted_query_fields": ["data_name", "dataset"],
+ }
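The `attributes` table doubles as the record schema consumed by `load_eeg_attrs_from_bids_file` in the first hunk: records are pre-seeded with every expected attribute set to None and then filled where extraction succeeds. A minimal sketch (the `eegdash.const` import path follows the `from .const import ...` lines above but is still an assumption):

    from eegdash.const import config as data_config  # assumed import path

    attrs = {field: None for field in data_config["attributes"]}
    attrs["dataset"] = "ds005516"  # fill whatever can be extracted
    missing = [k for k, v in attrs.items() if v is None]
    # all schema keys are present; unextracted ones simply remain None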