braindecode 1.3.0.dev177069446__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (124)
  1. braindecode/__init__.py +9 -0
  2. braindecode/augmentation/__init__.py +52 -0
  3. braindecode/augmentation/base.py +225 -0
  4. braindecode/augmentation/functional.py +1300 -0
  5. braindecode/augmentation/transforms.py +1356 -0
  6. braindecode/classifier.py +258 -0
  7. braindecode/datasets/__init__.py +44 -0
  8. braindecode/datasets/base.py +823 -0
  9. braindecode/datasets/bbci.py +693 -0
  10. braindecode/datasets/bcicomp.py +193 -0
  11. braindecode/datasets/bids/__init__.py +54 -0
  12. braindecode/datasets/bids/datasets.py +239 -0
  13. braindecode/datasets/bids/format.py +717 -0
  14. braindecode/datasets/bids/hub.py +987 -0
  15. braindecode/datasets/bids/hub_format.py +717 -0
  16. braindecode/datasets/bids/hub_io.py +197 -0
  17. braindecode/datasets/bids/hub_validation.py +114 -0
  18. braindecode/datasets/bids/iterable.py +220 -0
  19. braindecode/datasets/chb_mit.py +163 -0
  20. braindecode/datasets/mne.py +170 -0
  21. braindecode/datasets/moabb.py +219 -0
  22. braindecode/datasets/nmt.py +313 -0
  23. braindecode/datasets/registry.py +120 -0
  24. braindecode/datasets/siena.py +162 -0
  25. braindecode/datasets/sleep_physio_challe_18.py +411 -0
  26. braindecode/datasets/sleep_physionet.py +125 -0
  27. braindecode/datasets/tuh.py +591 -0
  28. braindecode/datasets/utils.py +67 -0
  29. braindecode/datasets/xy.py +96 -0
  30. braindecode/datautil/__init__.py +62 -0
  31. braindecode/datautil/channel_utils.py +114 -0
  32. braindecode/datautil/hub_formats.py +180 -0
  33. braindecode/datautil/serialization.py +359 -0
  34. braindecode/datautil/util.py +154 -0
  35. braindecode/eegneuralnet.py +372 -0
  36. braindecode/functional/__init__.py +22 -0
  37. braindecode/functional/functions.py +251 -0
  38. braindecode/functional/initialization.py +47 -0
  39. braindecode/models/__init__.py +117 -0
  40. braindecode/models/atcnet.py +830 -0
  41. braindecode/models/attentionbasenet.py +727 -0
  42. braindecode/models/attn_sleep.py +549 -0
  43. braindecode/models/base.py +574 -0
  44. braindecode/models/bendr.py +493 -0
  45. braindecode/models/biot.py +537 -0
  46. braindecode/models/brainmodule.py +845 -0
  47. braindecode/models/config.py +233 -0
  48. braindecode/models/contrawr.py +319 -0
  49. braindecode/models/ctnet.py +541 -0
  50. braindecode/models/deep4.py +376 -0
  51. braindecode/models/deepsleepnet.py +417 -0
  52. braindecode/models/eegconformer.py +475 -0
  53. braindecode/models/eeginception_erp.py +379 -0
  54. braindecode/models/eeginception_mi.py +379 -0
  55. braindecode/models/eegitnet.py +302 -0
  56. braindecode/models/eegminer.py +256 -0
  57. braindecode/models/eegnet.py +359 -0
  58. braindecode/models/eegnex.py +354 -0
  59. braindecode/models/eegsimpleconv.py +201 -0
  60. braindecode/models/eegsym.py +917 -0
  61. braindecode/models/eegtcnet.py +337 -0
  62. braindecode/models/fbcnet.py +225 -0
  63. braindecode/models/fblightconvnet.py +315 -0
  64. braindecode/models/fbmsnet.py +338 -0
  65. braindecode/models/hybrid.py +126 -0
  66. braindecode/models/ifnet.py +443 -0
  67. braindecode/models/labram.py +1316 -0
  68. braindecode/models/luna.py +891 -0
  69. braindecode/models/medformer.py +760 -0
  70. braindecode/models/msvtnet.py +377 -0
  71. braindecode/models/patchedtransformer.py +640 -0
  72. braindecode/models/reve.py +843 -0
  73. braindecode/models/sccnet.py +280 -0
  74. braindecode/models/shallow_fbcsp.py +212 -0
  75. braindecode/models/signal_jepa.py +1122 -0
  76. braindecode/models/sinc_shallow.py +339 -0
  77. braindecode/models/sleep_stager_blanco_2020.py +169 -0
  78. braindecode/models/sleep_stager_chambon_2018.py +159 -0
  79. braindecode/models/sparcnet.py +426 -0
  80. braindecode/models/sstdpn.py +869 -0
  81. braindecode/models/summary.csv +47 -0
  82. braindecode/models/syncnet.py +234 -0
  83. braindecode/models/tcn.py +275 -0
  84. braindecode/models/tidnet.py +397 -0
  85. braindecode/models/tsinception.py +295 -0
  86. braindecode/models/usleep.py +439 -0
  87. braindecode/models/util.py +369 -0
  88. braindecode/modules/__init__.py +92 -0
  89. braindecode/modules/activation.py +86 -0
  90. braindecode/modules/attention.py +883 -0
  91. braindecode/modules/blocks.py +160 -0
  92. braindecode/modules/convolution.py +330 -0
  93. braindecode/modules/filter.py +654 -0
  94. braindecode/modules/layers.py +216 -0
  95. braindecode/modules/linear.py +70 -0
  96. braindecode/modules/parametrization.py +38 -0
  97. braindecode/modules/stats.py +87 -0
  98. braindecode/modules/util.py +85 -0
  99. braindecode/modules/wrapper.py +90 -0
  100. braindecode/preprocessing/__init__.py +271 -0
  101. braindecode/preprocessing/eegprep_preprocess.py +1317 -0
  102. braindecode/preprocessing/mne_preprocess.py +240 -0
  103. braindecode/preprocessing/preprocess.py +579 -0
  104. braindecode/preprocessing/util.py +177 -0
  105. braindecode/preprocessing/windowers.py +1037 -0
  106. braindecode/regressor.py +234 -0
  107. braindecode/samplers/__init__.py +18 -0
  108. braindecode/samplers/base.py +399 -0
  109. braindecode/samplers/ssl.py +263 -0
  110. braindecode/training/__init__.py +23 -0
  111. braindecode/training/callbacks.py +23 -0
  112. braindecode/training/losses.py +105 -0
  113. braindecode/training/scoring.py +477 -0
  114. braindecode/util.py +419 -0
  115. braindecode/version.py +1 -0
  116. braindecode/visualization/__init__.py +8 -0
  117. braindecode/visualization/confusion_matrices.py +289 -0
  118. braindecode/visualization/gradients.py +62 -0
  119. braindecode-1.3.0.dev177069446.dist-info/METADATA +230 -0
  120. braindecode-1.3.0.dev177069446.dist-info/RECORD +124 -0
  121. braindecode-1.3.0.dev177069446.dist-info/WHEEL +5 -0
  122. braindecode-1.3.0.dev177069446.dist-info/licenses/LICENSE.txt +31 -0
  123. braindecode-1.3.0.dev177069446.dist-info/licenses/NOTICE.txt +20 -0
  124. braindecode-1.3.0.dev177069446.dist-info/top_level.txt +1 -0
braindecode/datasets/bids/hub.py
@@ -0,0 +1,987 @@
1
+ # mypy: ignore-errors
2
+ """
3
+ Hugging Face Hub integration for EEG datasets.
4
+
5
+ This module provides push_to_hub() and pull_from_hub() functionality
6
+ for braindecode datasets, similar to the model Hub integration.
7
+
8
+ .. warning::
9
+ The format is **BIDS-inspired**, not **BIDS-compliant**. The metadata
10
+ files are BIDS-compliant, but the data is stored in Zarr format for
11
+ efficient training, which is not a valid BIDS EEG format.
12
+
13
+ The format follows a BIDS-inspired sourcedata structure:
14
+ - sourcedata/braindecode/
15
+ - dataset_description.json (BIDS-compliant)
16
+ - participants.tsv (BIDS-compliant)
17
+ - dataset.zarr/ (NOT BIDS-compliant - efficient data store)
18
+ - sub-<label>/
19
+ - eeg/
20
+ - *_events.tsv (BIDS-compliant)
21
+ - *_channels.tsv (BIDS-compliant)
22
+ - *_eeg.json (BIDS-compliant)
23
+ """
24
+
25
+ # Authors: Kuntal Kokate
26
+ # Bruno Aristimunha <b.aristimunha@gmail.com>
27
+ #
28
+ # License: BSD (3-clause)
29
+
30
+ import json
31
+ import logging
32
+ import tempfile
33
+ from datetime import timedelta
34
+ from pathlib import Path
35
+ from typing import TYPE_CHECKING, List, Optional, Union
36
+
37
+ import mne
38
+ import numpy as np
39
+ import pandas as pd
40
+ import scipy
41
+ from mne._fiff.meas_info import Info
42
+ from mne.utils import _soft_import
43
+
44
+ if TYPE_CHECKING:
45
+ from ..base import BaseDataset
46
+
47
+ import braindecode
48
+
49
+ # Import registry for dynamic class lookup (avoids circular imports)
50
+ from ..registry import get_dataset_class, get_dataset_type
51
+
52
+ # Hub format and validation utilities
53
+ from . import hub_format, hub_validation
54
+ from .hub_io import (
55
+ _create_compressor,
56
+ _load_eegwindows_from_zarr,
57
+ _load_raw_from_zarr,
58
+ _load_windows_from_zarr,
59
+ _save_eegwindows_to_zarr,
60
+ _save_raw_to_zarr,
61
+ _save_windows_to_zarr,
62
+ )
63
+
64
+ # Lazy import zarr and huggingface_hub
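+ # With strict=False, _soft_import returns False when a package is missing,
+ # which is why the Hub methods below guard their entry points with
+ # ``is False`` checks before touching these modules.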
65
+ zarr = _soft_import("zarr", purpose="hugging face integration", strict=False)
66
+ huggingface_hub = _soft_import(
67
+ "huggingface_hub", purpose="hugging face integration", strict=False
68
+ )
69
+
70
+ log = logging.getLogger(__name__)
71
+
72
+
73
+ class HubDatasetMixin:
74
+ """
75
+ Mixin class for Hugging Face Hub integration with EEG datasets.
76
+
77
+ This class adds `push_to_hub()` and `pull_from_hub()` methods to
78
+ BaseConcatDataset, enabling easy upload and download of datasets
79
+ to/from the Hugging Face Hub.
80
+
81
+ Examples
82
+ --------
83
+ >>> # Push dataset to Hub
84
+ >>> dataset = NMT(path=path, preload=True)
85
+ >>> dataset.push_to_hub(
86
+ ... repo_id="username/nmt-dataset",
87
+ ... commit_message="Add NMT dataset"
88
+ ... )
89
+ >>>
90
+ >>> # Load dataset from Hub
91
+ >>> dataset = BaseConcatDataset.pull_from_hub("username/nmt-dataset")
92
+ """
93
+
94
+ datasets: List["BaseDataset"] # Attribute provided by inheriting class
95
+
96
+ def push_to_hub(
97
+ self,
98
+ repo_id: str,
99
+ commit_message: Optional[str] = None,
100
+ private: bool = False,
101
+ token: Optional[str] = None,
102
+ create_pr: bool = False,
103
+ compression: str = "blosc",
104
+ compression_level: int = 5,
105
+ pipeline_name: str = "braindecode",
106
+ ) -> str:
107
+ """
108
+ Upload the dataset to the Hugging Face Hub in BIDS-like Zarr format.
109
+
110
+ The dataset is converted to Zarr format with blosc compression, which provides
111
+ efficient random access for PyTorch training. The data is stored
112
+ in a BIDS sourcedata-like structure with events.tsv, channels.tsv,
113
+ and participants.tsv sidecar files.
114
+
115
+ Parameters
116
+ ----------
117
+ repo_id : str
118
+ Repository ID on the Hugging Face Hub (e.g., "username/dataset-name").
119
+ commit_message : str | None
120
+ Commit message. If None, a default message is generated.
121
+ private : bool, default=False
122
+ Whether to create a private repository.
123
+ token : str | None
124
+ Hugging Face API token. If None, uses cached token.
125
+ create_pr : bool, default=False
126
+ Whether to create a Pull Request instead of directly committing.
127
+ compression : str, default="blosc"
128
+ Compression algorithm for Zarr. Options: "blosc", "zstd", "gzip", None.
129
+ compression_level : int, default=5
130
+ Compression level (0-9). Level 5 gives a good balance of size and speed.
131
+ pipeline_name : str, default="braindecode"
132
+ Name of the processing pipeline for BIDS sourcedata.
133
+
134
+ Returns
135
+ -------
136
+ str
137
+ URL of the uploaded dataset on the Hub.
138
+
139
+ Raises
140
+ ------
141
+ ImportError
142
+ If huggingface-hub is not installed.
143
+ ValueError
144
+ If the dataset is empty or format is invalid.
145
+
146
+ Examples
147
+ --------
148
+ >>> dataset = NMT(path=path, preload=True)
149
+ >>> # Upload with BIDS-like structure
150
+ >>> url = dataset.push_to_hub(
151
+ ... repo_id="myusername/nmt-dataset",
152
+ ... commit_message="Upload NMT EEG dataset"
153
+ ... )
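+ >>>
+ >>> # Optionally upload privately with a different compressor
+ >>> # (illustrative values for the documented parameters):
+ >>> url = dataset.push_to_hub(
+ ... repo_id="myusername/nmt-dataset",
+ ... private=True,
+ ... compression="zstd",
+ ... compression_level=7,
+ ... )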
154
+ """
155
+ if huggingface_hub is False or zarr is False:
156
+ raise ImportError(
157
+ "huggingface-hub or zarr is not installed. Install with: "
158
+ "pip install braindecode[hub]"
159
+ )
160
+
161
+ # Create API instance
162
+ _ = huggingface_hub.HfApi(token=token)
163
+
164
+ # Create repository if it doesn't exist
165
+ try:
166
+ huggingface_hub.create_repo(
167
+ repo_id=repo_id,
168
+ token=token,
169
+ private=private,
170
+ repo_type="dataset",
171
+ exist_ok=True,
172
+ )
173
+ except Exception as e:
174
+ raise RuntimeError(f"Failed to create repository: {e}")
175
+
176
+ # Create a temporary directory for upload
177
+ with tempfile.TemporaryDirectory() as tmpdir:
178
+ tmp_path = Path(tmpdir)
179
+
180
+ # Create BIDS-like sourcedata structure
181
+ log.info("Creating BIDS-like sourcedata structure...")
182
+ bids_layout = hub_format.BIDSSourcedataLayout(
183
+ tmp_path, pipeline_name=pipeline_name
184
+ )
185
+ sourcedata_dir = bids_layout.create_structure()
186
+
187
+ # Save dataset_description.json
188
+ bids_layout.save_dataset_description()
189
+
190
+ # Save participants.tsv
191
+ descriptions = [ds.description for ds in self.datasets]
192
+ bids_layout.save_participants(descriptions)
193
+
194
+ # Save BIDS sidecar files for each recording
195
+ self._save_bids_sidecar_files(bids_layout)
196
+
197
+ # Convert dataset to Zarr format inside sourcedata
198
+ log.info("Converting dataset to Zarr format...")
199
+ dataset_path = sourcedata_dir / "dataset.zarr"
200
+
201
+ self._convert_to_zarr_inline(
202
+ dataset_path,
203
+ compression,
204
+ compression_level,
205
+ )
206
+
207
+ # Save dataset metadata (README.md)
208
+ self._save_dataset_card(tmp_path)
209
+
210
+ # Save format info
211
+ format_info_path = tmp_path / "format_info.json"
212
+ with open(format_info_path, "w", encoding="utf-8") as f:
213
+ format_info = self._get_format_info_inline()
214
+ json.dump(
215
+ {
216
+ "format": "zarr",
217
+ "pipeline_name": pipeline_name,
218
+ "compression": compression,
219
+ "compression_level": compression_level,
220
+ "braindecode_version": braindecode.__version__,
221
+ **format_info,
222
+ },
223
+ f,
224
+ indent=2,
225
+ )
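+ # For illustration, the resulting format_info.json might look like:
+ # {
+ #   "format": "zarr",
+ #   "pipeline_name": "braindecode",
+ #   "compression": "blosc",
+ #   "compression_level": 5,
+ #   "braindecode_version": "1.3.0",
+ #   "n_recordings": 10,
+ #   "total_samples": 12000,
+ #   "total_size_mb": 512.34
+ # }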
226
+
227
+ # Default commit message
228
+ if commit_message is None:
229
+ commit_message = (
230
+ f"Upload EEG dataset in BIDS-like "
231
+ f"Zarr format ({len(self.datasets)} recordings)"
232
+ )
233
+
234
+ # Upload folder to Hub
235
+ log.info(f"Uploading to Hugging Face Hub ({repo_id})...")
236
+ try:
237
+ url = huggingface_hub.upload_folder(
238
+ repo_id=repo_id,
239
+ folder_path=str(tmp_path),
240
+ repo_type="dataset",
241
+ commit_message=commit_message,
242
+ token=token,
243
+ create_pr=create_pr,
244
+ )
245
+ log.info(f"Dataset uploaded successfully to {repo_id}")
246
+ log.info(f"URL: https://huggingface.co/datasets/{repo_id}")
247
+ return url
248
+ except Exception as e:
249
+ raise RuntimeError(f"Failed to upload dataset: {e}")
250
+
251
+ def _save_dataset_card(self, path: Path, bids_inspired: bool = True) -> None:
252
+ """Generate and save a dataset card (README.md) with metadata.
253
+
254
+ Parameters
255
+ ----------
256
+ path : Path
257
+ Directory where README.md will be saved.
258
+ bids_inspired : bool
259
+ Whether to include BIDS-inspired format documentation.
260
+ """
261
+ # Get info, which also validates uniformity across all datasets
262
+ format_info = self._get_format_info_inline()
263
+
264
+ n_recordings = len(self.datasets)
265
+ first_ds = self.datasets[0]
266
+
267
+ # Get dataset-specific info based on type using registry
268
+ dataset_type = get_dataset_type(first_ds)
269
+
270
+ n_windows = format_info["total_samples"]
271
+
272
+ # Compute total duration across all recordings
273
+ total_duration = 0.0
274
+ if dataset_type == "WindowsDataset":
275
+ n_channels = len(first_ds.windows.ch_names)
276
+ data_type = "Windowed (from Epochs object)"
277
+ sfreq = first_ds.windows.info["sfreq"]
278
+ for ds in self.datasets:
279
+ epoch_length = ds.windows.tmax - ds.windows.tmin
280
+ total_duration += len(ds.windows) * epoch_length
281
+ elif dataset_type == "EEGWindowsDataset":
282
+ n_channels = len(first_ds.raw.ch_names)
283
+ sfreq = first_ds.raw.info["sfreq"]
284
+ data_type = "Windowed (from Raw object)"
285
+ for ds in self.datasets:
286
+ total_duration += ds.raw.n_times / ds.raw.info["sfreq"]
287
+ elif dataset_type == "RawDataset":
288
+ n_channels = len(first_ds.raw.ch_names)
289
+ sfreq = first_ds.raw.info["sfreq"]
290
+ data_type = "Continuous (Raw)"
291
+ for ds in self.datasets:
292
+ total_duration += ds.raw.n_times / ds.raw.info["sfreq"]
293
+ else:
294
+ raise TypeError(f"Unsupported dataset type: {dataset_type}")
295
+
296
+ # Create README content and save
297
+ readme_content = _generate_readme_content(
298
+ format_info=format_info,
299
+ n_recordings=n_recordings,
300
+ n_channels=n_channels,
301
+ sfreq=sfreq,
302
+ data_type=data_type,
303
+ n_windows=n_windows,
304
+ total_duration=total_duration,
305
+ )
306
+
307
+ # Save README
308
+ readme_path = path / "README.md"
309
+ with open(readme_path, "w", encoding="utf-8") as f:
310
+ f.write(readme_content)
311
+
312
+ def _save_bids_sidecar_files(
313
+ self, bids_layout: "hub_format.BIDSSourcedataLayout"
314
+ ) -> None:
315
+ """Save BIDS-compliant sidecar files for each recording.
316
+
317
+ This creates events.tsv, channels.tsv, and EEG sidecar JSON files
318
+ for each recording in a BIDS-like directory structure.
319
+
320
+ Parameters
321
+ ----------
322
+ bids_layout : BIDSSourcedataLayout
323
+ BIDS layout object for path generation.
324
+ """
325
+ dataset_type = get_dataset_type(self.datasets[0])
326
+
327
+ for i_ds, ds in enumerate(self.datasets):
328
+ # Get BIDS entities from description
329
+ description = ds.description if ds.description is not None else pd.Series()
330
+
331
+ # Get BIDSPath for this recording using mne_bids
332
+ bids_path = bids_layout.get_bids_path(description)
333
+
334
+ # Create subject directory
335
+ bids_path.mkdir(exist_ok=True)
336
+
337
+ # Get metadata and info based on dataset type
338
+ # Also compute recording_duration, recording_type, and epoch_length
339
+ recording_duration = None
340
+ recording_type = None
341
+ epoch_length = None
342
+
343
+ if dataset_type == "WindowsDataset":
344
+ info = ds.windows.info
345
+ metadata = ds.windows.metadata
346
+ sfreq = info["sfreq"]
347
+ # WindowsDataset contains pre-cut epochs
348
+ recording_type = "epoched"
349
+ # Use MNE's tmax - tmin for epoch length
350
+ epoch_length = ds.windows.tmax - ds.windows.tmin
351
+ # Total duration = number of epochs * epoch length
352
+ n_epochs = len(ds.windows)
353
+ recording_duration = n_epochs * epoch_length
354
+ elif dataset_type == "EEGWindowsDataset":
355
+ info = ds.raw.info
356
+ metadata = ds.metadata
357
+ sfreq = info["sfreq"]
358
+ # EEGWindowsDataset has continuous raw with window metadata
359
+ recording_type = "epoched"
360
+ # Use MNE Raw's duration property
361
+ recording_duration = ds.raw.duration
362
+ # Compute epoch_length from metadata if available
363
+ if metadata is not None and len(metadata) > 0:
364
+ i_start = metadata["i_start_in_trial"].iloc[0]
365
+ i_stop = metadata["i_stop_in_trial"].iloc[0]
366
+ epoch_length = (i_stop - i_start) / sfreq
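+ # Note: only the first metadata row is inspected, so this assumes all
+ # windows in the recording share the same length.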
367
+ elif dataset_type == "RawDataset":
368
+ info = ds.raw.info
369
+ metadata = None
370
+ sfreq = info["sfreq"]
371
+ # RawDataset is continuous
372
+ recording_type = "continuous"
373
+ # Use MNE Raw's duration property
374
+ recording_duration = ds.raw.duration
375
+ else:
376
+ continue
377
+
378
+ # Determine task name from description or BIDSPath
379
+ task_name = bids_path.task or "unknown"
380
+
381
+ # Save BIDS sidecar files using mne_bids BIDSPath
382
+ hub_format.save_bids_sidecar_files(
383
+ bids_path=bids_path,
384
+ info=info,
385
+ metadata=metadata,
386
+ sfreq=sfreq,
387
+ task_name=str(task_name),
388
+ recording_duration=recording_duration,
389
+ recording_type=recording_type,
390
+ epoch_length=epoch_length,
391
+ )
392
+
393
+ log.debug(
394
+ f"Saved BIDS sidecar files for recording {i_ds} to {bids_path.directory}"
395
+ )
396
+
397
+ @classmethod
398
+ def pull_from_hub(
399
+ cls,
400
+ repo_id: str,
401
+ preload: bool = True,
402
+ token: Optional[str] = None,
403
+ cache_dir: Optional[Union[str, Path]] = None,
404
+ force_download: bool = False,
405
+ **kwargs,
406
+ ):
407
+ """
408
+ Load a dataset from the Hugging Face Hub.
409
+
410
+ Parameters
411
+ ----------
412
+ repo_id : str
413
+ Repository ID on the Hugging Face Hub (e.g., "username/dataset-name").
414
+ preload : bool, default=True
415
+ Whether to preload the data into memory. If False, uses lazy loading
416
+ (when supported by the format).
417
+ token : str | None
418
+ Hugging Face API token. If None, uses cached token.
419
+ cache_dir : str | Path | None
420
+ Directory to cache the downloaded dataset. If None, uses default
421
+ cache directory (~/.cache/huggingface/datasets).
422
+ force_download : bool, default=False
423
+ Whether to force re-download even if cached.
424
+ **kwargs
425
+ Additional arguments (currently unused).
426
+
427
+ Returns
428
+ -------
429
+ BaseConcatDataset
430
+ The loaded dataset.
431
+
432
+ Raises
433
+ ------
434
+ ImportError
435
+ If huggingface-hub is not installed.
436
+ FileNotFoundError
437
+ If the repository or dataset files are not found.
438
+
439
+ Examples
440
+ --------
441
+ >>> from braindecode.datasets import BaseConcatDataset
442
+ >>> dataset = BaseConcatDataset.pull_from_hub("username/nmt-dataset")
443
+ >>> print(f"Loaded {len(dataset)} windows")
444
+ >>>
445
+ >>> # Use with PyTorch
446
+ >>> from torch.utils.data import DataLoader
447
+ >>> loader = DataLoader(dataset, batch_size=32, shuffle=True)
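+ >>>
+ >>> # Lazy-load into a custom cache directory (illustrative arguments):
+ >>> dataset = BaseConcatDataset.pull_from_hub(
+ ... "username/nmt-dataset", preload=False, cache_dir="./hf_cache"
+ ... )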
448
+ """
449
+ if zarr is False or huggingface_hub is False:
450
+ raise ImportError(
451
+ "huggingface hub functionality is not installed. Install with: "
452
+ "pip install braindecode[hub]"
453
+ )
454
+
455
+ log.info(f"Loading dataset from Hugging Face Hub ({repo_id})...")
456
+
457
+ try:
458
+ # Download the entire dataset directory
459
+ dataset_dir = huggingface_hub.snapshot_download(
460
+ repo_id=repo_id,
461
+ repo_type="dataset",
462
+ token=token,
463
+ cache_dir=cache_dir,
464
+ force_download=force_download,
465
+ )
466
+
467
+ # Load format info
468
+ format_info_path = Path(dataset_dir) / "format_info.json"
469
+ if format_info_path.exists():
470
+ with open(format_info_path, "r") as f:
471
+ format_info = json.load(f)
472
+
473
+ # Verify it's zarr format
474
+ if format_info.get("format") != "zarr":
475
+ raise ValueError(
476
+ f"Dataset format is '{format_info.get('format')}', but only "
477
+ "'zarr' format is supported. Please re-upload the dataset."
478
+ )
479
+ else:
480
+ format_info = {}
481
+
482
+ pipeline_name = format_info.get("pipeline_name", "braindecode")
483
+
484
+ # Find zarr dataset path (try sourcedata, derivatives, then root)
485
+ zarr_path = (
486
+ Path(dataset_dir) / "sourcedata" / pipeline_name / "dataset.zarr"
487
+ )
488
+ if not zarr_path.exists():
489
+ zarr_path = (
490
+ Path(dataset_dir) / "derivatives" / pipeline_name / "dataset.zarr"
491
+ )
492
+ if not zarr_path.exists():
493
+ zarr_path = Path(dataset_dir) / "dataset.zarr"
494
+
495
+ if not zarr_path.exists():
496
+ raise FileNotFoundError(
497
+ f"Zarr dataset not found at {zarr_path}. "
498
+ "The dataset may be in an unsupported format."
499
+ )
500
+
501
+ dataset = cls._load_from_zarr_inline(zarr_path, preload)
502
+
503
+ # Load BIDS metadata if available
504
+ cls._load_bids_metadata(dataset, Path(dataset_dir), pipeline_name)
505
+
506
+ log.info(f"Dataset loaded successfully from {repo_id}")
507
+ log.info(f"Recordings: {len(dataset.datasets)}")
508
+ log.info(
509
+ f"Total windows/samples: {format_info.get('total_samples', 'N/A')}"
510
+ )
511
+
512
+ return dataset
513
+
514
+ except huggingface_hub.utils.HfHubHTTPError as e:
515
+ if e.response.status_code == 404:
516
+ raise FileNotFoundError(
517
+ f"Dataset '{repo_id}' not found on Hugging Face Hub. "
518
+ "Please check the repository ID and ensure it exists."
519
+ )
520
+ else:
521
+ raise RuntimeError(f"Failed to download dataset: {e}")
522
+ except Exception as e:
523
+ raise RuntimeError(f"Failed to load dataset from Hub: {e}")
524
+
525
+ @classmethod
526
+ def _load_bids_metadata(
527
+ cls,
528
+ dataset,
529
+ dataset_dir: Path,
530
+ pipeline_name: str,
531
+ ) -> None:
532
+ """Load BIDS metadata from sidecar files and attach to dataset.
533
+
534
+ Parameters
535
+ ----------
536
+ dataset : BaseConcatDataset
537
+ The loaded dataset to attach metadata to.
538
+ dataset_dir : Path
539
+ Root directory of the downloaded dataset.
540
+ pipeline_name : str
541
+ Name of the processing pipeline.
542
+ """
543
+ # Try sourcedata first, fall back to derivatives for backwards compatibility
544
+ sourcedata_dir = dataset_dir / "sourcedata" / pipeline_name
545
+ if not sourcedata_dir.exists():
546
+ sourcedata_dir = dataset_dir / "derivatives" / pipeline_name
547
+
548
+ # Load participants.tsv if available
549
+ participants_path = sourcedata_dir / "participants.tsv"
550
+ if participants_path.exists():
551
+ try:
552
+ participants_df = pd.read_csv(participants_path, sep="\t")
553
+ # Store as attribute on the concat dataset
554
+ dataset.participants = participants_df
555
+ log.debug(
556
+ f"Loaded participants info for {len(participants_df)} subjects"
557
+ )
558
+ except Exception as e:
559
+ log.warning(f"Failed to load participants.tsv: {e}")
560
+
561
+ # Create layout for path generation
562
+ bids_layout = hub_format.BIDSSourcedataLayout(
563
+ dataset_dir, pipeline_name=pipeline_name
564
+ )
565
+
566
+ # Try to load events.tsv files and attach to individual datasets
567
+ for i_ds, ds in enumerate(dataset.datasets):
568
+ description = ds.description if ds.description is not None else pd.Series()
569
+
570
+ # Get BIDSPath for this recording
571
+ bids_path = bids_layout.get_bids_path(description)
572
+
573
+ # Load events.tsv if available
574
+ events_path = bids_path.copy().update(suffix="events", extension=".tsv")
575
+ if events_path.fpath.exists():
576
+ try:
577
+ events_df = pd.read_csv(events_path.fpath, sep="\t")
578
+ ds.bids_events = events_df
579
+ log.debug(f"Loaded events for recording {i_ds}")
580
+ except Exception as e:
581
+ log.warning(f"Failed to load events for recording {i_ds}: {e}")
582
+
583
+ # Load channels.tsv if available
584
+ channels_path = bids_path.copy().update(suffix="channels", extension=".tsv")
585
+ if channels_path.fpath.exists():
586
+ try:
587
+ channels_df = pd.read_csv(channels_path.fpath, sep="\t")
588
+ ds.bids_channels = channels_df
589
+ log.debug(f"Loaded channels for recording {i_ds}")
590
+ except Exception as e:
591
+ log.warning(f"Failed to load channels for recording {i_ds}: {e}")
592
+
593
+ def _convert_to_zarr_inline(
594
+ self,
595
+ output_path: Path,
596
+ compression: str,
597
+ compression_level: int,
598
+ ) -> None:
599
+ """Convert dataset to Zarr format (inline implementation)."""
600
+
601
+ if zarr is False or huggingface_hub is False:
602
+ raise ImportError(
603
+ "huggingface hub functionality is not installed. Install with: "
604
+ "pip install braindecode[hub]"
605
+ )
606
+
607
+ if output_path.exists():
608
+ raise FileExistsError(
609
+ f"{output_path} already exists. Set overwrite=True to replace it."
610
+ )
611
+
612
+ # Create zarr store (zarr v3 API)
613
+ root = zarr.open(str(output_path), mode="w")
614
+
615
+ # Validate uniformity across all datasets using shared validation
616
+ dataset_type, _, _ = hub_validation.validate_dataset_uniformity(self.datasets)
617
+
618
+ # Keep reference to first dataset for preprocessing kwargs
619
+ first_ds = self.datasets[0]
620
+
621
+ # Store global metadata
622
+ root.attrs["n_datasets"] = len(self.datasets)
623
+ root.attrs["dataset_type"] = dataset_type
624
+ root.attrs["braindecode_version"] = braindecode.__version__
625
+
626
+ # Track dependency versions for reproducibility
627
+ root.attrs["mne_version"] = mne.__version__
628
+ root.attrs["numpy_version"] = np.__version__
629
+ root.attrs["pandas_version"] = pd.__version__
630
+ root.attrs["zarr_version"] = zarr.__version__
631
+ root.attrs["scipy_version"] = scipy.__version__
632
+
633
+ # Save preprocessing kwargs (check first dataset, assuming uniform preprocessing)
634
+ # These are typically set by windowing functions on individual datasets
635
+ for kwarg_name in [
636
+ "raw_preproc_kwargs",
637
+ "window_kwargs",
638
+ "window_preproc_kwargs",
639
+ ]:
640
+ # Check first dataset for these attributes
641
+ if hasattr(first_ds, kwarg_name):
642
+ kwargs = getattr(first_ds, kwarg_name)
643
+ if kwargs:
644
+ root.attrs[kwarg_name] = json.dumps(kwargs)
645
+
646
+ # Create compressor
647
+ compressor = _create_compressor(compression, compression_level)
648
+
649
+ # Save each recording
650
+ for i_ds, ds in enumerate(self.datasets):
651
+ grp = root.create_group(f"recording_{i_ds}")
652
+
653
+ if dataset_type == "WindowsDataset":
654
+ # Extract data from WindowsDataset
655
+ data = ds.windows.get_data()
656
+ metadata = ds.windows.metadata
657
+ description = ds.description
658
+ info_dict = ds.windows.info.to_json_dict()
659
+ target_name = ds.target_name if hasattr(ds, "target_name") else None
660
+
661
+ # Save using inlined function
662
+ _save_windows_to_zarr(
663
+ grp, data, metadata, description, info_dict, compressor, target_name
664
+ )
665
+
666
+ elif dataset_type == "EEGWindowsDataset":
667
+ # Get continuous raw data and metadata from EEGWindowsDataset
668
+ raw = ds.raw
669
+ metadata = ds.metadata
670
+ description = ds.description
671
+ info_dict = ds.raw.info.to_json_dict()
672
+ targets_from = ds.targets_from
673
+ last_target_only = ds.last_target_only
674
+
675
+ # Save using inlined function (saves continuous raw directly)
676
+ _save_eegwindows_to_zarr(
677
+ grp,
678
+ raw,
679
+ metadata,
680
+ description,
681
+ info_dict,
682
+ targets_from,
683
+ last_target_only,
684
+ compressor,
685
+ )
686
+
687
+ elif dataset_type == "RawDataset":
688
+ # Get continuous raw data from RawDataset
689
+ raw = ds.raw
690
+ description = ds.description
691
+ info_dict = ds.raw.info.to_json_dict()
692
+ target_name = ds.target_name if hasattr(ds, "target_name") else None
693
+
694
+ # Save using inlined function
695
+ _save_raw_to_zarr(
696
+ grp, raw, description, info_dict, target_name, compressor
697
+ )
698
+
699
+ def _get_format_info_inline(self):
700
+ """Get format info (inline implementation).
701
+
702
+ This is an inline version of hub_formats.get_format_info() that avoids
703
+ circular import.
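+
+ For example, the returned dict looks like (illustrative values)::
+
+ {"n_recordings": 10, "total_samples": 12000, "total_size_mb": 512.34}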
704
+ """
705
+ if len(self.datasets) == 0:
706
+ raise ValueError("Cannot get format info for empty dataset")
707
+
708
+ # Validate uniformity across all datasets using shared validation
709
+ dataset_type, _, _ = hub_validation.validate_dataset_uniformity(self.datasets)
710
+
711
+ # Calculate dataset size
712
+ # BaseConcatDataset's __len__ already sums len(ds) for all datasets
713
+ total_samples = len(self)
714
+ total_size_mb = 0
715
+
716
+ for ds in self.datasets:
717
+ if dataset_type == "WindowsDataset":
718
+ # Use MNE's internal _size property to avoid loading data
719
+ total_size_mb += ds.windows._size / (1024 * 1024)
720
+ elif dataset_type == "EEGWindowsDataset":
721
+ # Use raw object's size (not extracted windows)
722
+ total_size_mb += ds.raw._size / (1024 * 1024)
723
+ elif dataset_type == "RawDataset":
724
+ total_size_mb += ds.raw._size / (1024 * 1024)
725
+
726
+ n_recordings = len(self.datasets)
727
+
728
+ return {
729
+ "n_recordings": n_recordings,
730
+ "total_samples": total_samples,
731
+ "total_size_mb": round(total_size_mb, 2),
732
+ }
733
+
734
+ @staticmethod
735
+ def _load_from_zarr_inline(input_path: Path, preload: bool):
736
+ """Load dataset from Zarr format (inline implementation).
737
+
738
+ This is an inline version of hub_formats.load_from_zarr() that avoids
739
+ circular import by using the registry and hub_io helpers directly.
740
+ """
741
+ if not input_path.exists():
742
+ raise FileNotFoundError(f"{input_path} does not exist.")
743
+
744
+ # Open zarr store (zarr v3 API)
745
+ root = zarr.open(str(input_path), mode="r")
746
+
747
+ n_datasets = root.attrs["n_datasets"]
748
+ dataset_type = root.attrs["dataset_type"]
749
+
750
+ # Get dataset classes from registry
751
+ WindowsDataset = get_dataset_class("WindowsDataset")
752
+ EEGWindowsDataset = get_dataset_class("EEGWindowsDataset")
753
+ RawDataset = get_dataset_class("RawDataset")
754
+ BaseConcatDataset = get_dataset_class("BaseConcatDataset")
755
+
756
+ datasets = []
757
+ for i_ds in range(n_datasets):
758
+ grp = root[f"recording_{i_ds}"]
759
+
760
+ if dataset_type == "WindowsDataset":
761
+ # Load using inlined function
762
+ data, metadata, description, info_dict, target_name = (
763
+ _load_windows_from_zarr(grp, preload)
764
+ )
765
+
766
+ # Convert to MNE objects and create dataset
767
+ info = Info.from_json_dict(info_dict)
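+ # MNE events arrays have three columns: onset sample, previous event
+ # value (unused here, set to 0), and event id (the window target).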
768
+ events = np.column_stack(
769
+ [
770
+ metadata["i_start_in_trial"].values,
771
+ np.zeros(len(metadata), dtype=int),
772
+ metadata["target"].values,
773
+ ]
774
+ )
775
+ epochs = mne.EpochsArray(data, info, events=events, metadata=metadata)
776
+ ds = WindowsDataset(epochs, description)
777
+ if target_name is not None:
778
+ ds.target_name = target_name
779
+
780
+ elif dataset_type == "EEGWindowsDataset":
781
+ # Load using inlined function
782
+ (
783
+ data,
784
+ metadata,
785
+ description,
786
+ info_dict,
787
+ targets_from,
788
+ last_target_only,
789
+ ) = _load_eegwindows_from_zarr(grp, preload)
790
+
791
+ # Convert to MNE objects and create dataset
792
+ # Data is already in continuous format [n_channels, n_timepoints]
793
+ info = Info.from_json_dict(info_dict)
794
+ raw = mne.io.RawArray(data, info)
795
+ ds = EEGWindowsDataset(
796
+ raw=raw,
797
+ metadata=metadata,
798
+ description=description,
799
+ targets_from=targets_from,
800
+ last_target_only=last_target_only,
801
+ )
802
+
803
+ elif dataset_type == "RawDataset":
804
+ # Load using inlined function
805
+ data, description, info_dict, target_name = _load_raw_from_zarr(
806
+ grp, preload
807
+ )
808
+
809
+ # Convert to MNE objects and create dataset
810
+ # Data is in continuous format [n_channels, n_timepoints]
811
+ info = Info.from_json_dict(info_dict)
812
+ raw = mne.io.RawArray(data, info)
813
+ ds = RawDataset(raw, description)
814
+ if target_name is not None:
815
+ ds.target_name = target_name
816
+
817
+ else:
818
+ raise ValueError(f"Unsupported dataset_type: {dataset_type}")
819
+
820
+ datasets.append(ds)
821
+
822
+ # Create concat dataset
823
+ concat_ds = BaseConcatDataset(datasets)
824
+
825
+ # Restore preprocessing kwargs (set on individual datasets, not concat)
826
+ for kwarg_name in [
827
+ "raw_preproc_kwargs",
828
+ "window_kwargs",
829
+ "window_preproc_kwargs",
830
+ ]:
831
+ if kwarg_name in root.attrs:
832
+ kwargs = json.loads(root.attrs[kwarg_name])
833
+ # Set on each individual dataset (where they were originally stored)
834
+ for ds in datasets:
835
+ setattr(ds, kwarg_name, kwargs)
836
+
837
+ return concat_ds
838
+
839
+
840
+ def _generate_readme_content(
841
+ format_info,
842
+ n_recordings: int,
843
+ n_channels: int,
844
+ sfreq,
845
+ data_type: str,
846
+ n_windows: int,
847
+ total_duration: Optional[float] = None,
848
+ format: str = "zarr",
849
+ ):
850
+ """Generate README.md content for a dataset uploaded to the Hub.
851
+
852
+ Parameters
853
+ ----------
854
+ format_info : dict
855
+ Dictionary containing format metadata (e.g., total_size_mb).
856
+ n_recordings : int
857
+ Number of recordings in the dataset.
858
+ n_channels : int
859
+ Number of EEG channels.
860
+ sfreq : float or None
861
+ Sampling frequency in Hz.
862
+ data_type : str
863
+ Type of dataset (e.g., "Windowed", "Continuous").
864
+ n_windows : int
865
+ Number of windows/samples in the dataset.
866
+ total_duration : float or None
867
+ Total duration in seconds across all recordings.
868
+ format : str
869
+ Storage format (default: "zarr").
870
+
871
+ Returns
872
+ -------
873
+ str
874
+ Markdown content for the README.md file.
875
+ """
876
+ total_size_mb = (
877
+ format_info.get("total_size_mb", 0.0) if isinstance(format_info, dict) else 0.0
878
+ )
879
+ sfreq_str = f"{sfreq:g}" if sfreq is not None else "N/A"
880
+
881
+ duration_str = (
882
+ str(timedelta(seconds=int(total_duration))) if total_duration else "N/A"
883
+ )
884
+
885
+ return f"""---
886
+ tags:
887
+ - braindecode
888
+ - eeg
889
+ - neuroscience
890
+ - brain-computer-interface
891
+ - deep-learning
892
+ license: unknown
893
+ ---
894
+
895
+ # EEG Dataset
896
+
897
+ This dataset was created using [braindecode](https://braindecode.org), a deep
898
+ learning library for EEG/MEG/ECoG signals.
899
+
900
+ ## Dataset Information
901
+
902
+ | Property | Value |
903
+ |----------|------:|
904
+ | Recordings | {n_recordings} |
905
+ | Type | {data_type} |
906
+ | Channels | {n_channels} |
907
+ | Sampling frequency | {sfreq_str} Hz |
908
+ | Total duration | {duration_str} |
909
+ | Windows/samples | {n_windows:,} |
910
+ | Size | {total_size_mb:.2f} MB |
911
+ | Format | {format} |
912
+
913
+ ## Quick Start
914
+
915
+ ```python
916
+ from braindecode.datasets import BaseConcatDataset
917
+
918
+ # Load from Hugging Face Hub
919
+ dataset = BaseConcatDataset.pull_from_hub("username/dataset-name")
920
+
921
+ # Access a sample
922
+ X, y, metainfo = dataset[0]
923
+ # X: EEG data [n_channels, n_times]
924
+ # y: target label
925
+ # metainfo: window indices
926
+ ```
927
+
928
+ ## Training with PyTorch
929
+
930
+ ```python
931
+ from torch.utils.data import DataLoader
932
+
933
+ loader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4)
934
+
935
+ for X, y, metainfo in loader:
936
+ # X: [batch_size, n_channels, n_times]
937
+ # y: [batch_size]
938
+ pass # Your training code
939
+ ```
940
+
941
+ ## BIDS-inspired Structure
942
+
943
+ This dataset uses a **BIDS-inspired** organization. Metadata files follow BIDS
944
+ conventions, while data is stored in Zarr format for efficient deep learning.
945
+
946
+ **BIDS-style metadata:**
947
+ - `dataset_description.json` - Dataset information
948
+ - `participants.tsv` - Subject metadata
949
+ - `*_events.tsv` - Trial/window events
950
+ - `*_channels.tsv` - Channel information
951
+ - `*_eeg.json` - Recording parameters
952
+
953
+ **Data storage:**
954
+ - `dataset.zarr/` - Zarr format (optimized for random access)
955
+
956
+ ```
957
+ sourcedata/braindecode/
958
+ ├── dataset_description.json
959
+ ├── participants.tsv
960
+ ├── dataset.zarr/
961
+ └── sub-<label>/
962
+ └── eeg/
963
+ ├── *_events.tsv
964
+ ├── *_channels.tsv
965
+ └── *_eeg.json
966
+ ```
967
+
968
+ ### Accessing Metadata
969
+
970
+ ```python
971
+ # Participants info
972
+ if hasattr(dataset, "participants"):
973
+ print(dataset.participants)
974
+
975
+ # Events for a recording
976
+ if hasattr(dataset.datasets[0], "bids_events"):
977
+ print(dataset.datasets[0].bids_events)
978
+
979
+ # Channel info
980
+ if hasattr(dataset.datasets[0], "bids_channels"):
981
+ print(dataset.datasets[0].bids_channels)
982
+ ```
983
+
984
+ ---
985
+
986
+ *Created with [braindecode](https://braindecode.org)*
987
+ """