braindecode 1.3.0.dev177069446__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- braindecode/__init__.py +9 -0
- braindecode/augmentation/__init__.py +52 -0
- braindecode/augmentation/base.py +225 -0
- braindecode/augmentation/functional.py +1300 -0
- braindecode/augmentation/transforms.py +1356 -0
- braindecode/classifier.py +258 -0
- braindecode/datasets/__init__.py +44 -0
- braindecode/datasets/base.py +823 -0
- braindecode/datasets/bbci.py +693 -0
- braindecode/datasets/bcicomp.py +193 -0
- braindecode/datasets/bids/__init__.py +54 -0
- braindecode/datasets/bids/datasets.py +239 -0
- braindecode/datasets/bids/format.py +717 -0
- braindecode/datasets/bids/hub.py +987 -0
- braindecode/datasets/bids/hub_format.py +717 -0
- braindecode/datasets/bids/hub_io.py +197 -0
- braindecode/datasets/bids/hub_validation.py +114 -0
- braindecode/datasets/bids/iterable.py +220 -0
- braindecode/datasets/chb_mit.py +163 -0
- braindecode/datasets/mne.py +170 -0
- braindecode/datasets/moabb.py +219 -0
- braindecode/datasets/nmt.py +313 -0
- braindecode/datasets/registry.py +120 -0
- braindecode/datasets/siena.py +162 -0
- braindecode/datasets/sleep_physio_challe_18.py +411 -0
- braindecode/datasets/sleep_physionet.py +125 -0
- braindecode/datasets/tuh.py +591 -0
- braindecode/datasets/utils.py +67 -0
- braindecode/datasets/xy.py +96 -0
- braindecode/datautil/__init__.py +62 -0
- braindecode/datautil/channel_utils.py +114 -0
- braindecode/datautil/hub_formats.py +180 -0
- braindecode/datautil/serialization.py +359 -0
- braindecode/datautil/util.py +154 -0
- braindecode/eegneuralnet.py +372 -0
- braindecode/functional/__init__.py +22 -0
- braindecode/functional/functions.py +251 -0
- braindecode/functional/initialization.py +47 -0
- braindecode/models/__init__.py +117 -0
- braindecode/models/atcnet.py +830 -0
- braindecode/models/attentionbasenet.py +727 -0
- braindecode/models/attn_sleep.py +549 -0
- braindecode/models/base.py +574 -0
- braindecode/models/bendr.py +493 -0
- braindecode/models/biot.py +537 -0
- braindecode/models/brainmodule.py +845 -0
- braindecode/models/config.py +233 -0
- braindecode/models/contrawr.py +319 -0
- braindecode/models/ctnet.py +541 -0
- braindecode/models/deep4.py +376 -0
- braindecode/models/deepsleepnet.py +417 -0
- braindecode/models/eegconformer.py +475 -0
- braindecode/models/eeginception_erp.py +379 -0
- braindecode/models/eeginception_mi.py +379 -0
- braindecode/models/eegitnet.py +302 -0
- braindecode/models/eegminer.py +256 -0
- braindecode/models/eegnet.py +359 -0
- braindecode/models/eegnex.py +354 -0
- braindecode/models/eegsimpleconv.py +201 -0
- braindecode/models/eegsym.py +917 -0
- braindecode/models/eegtcnet.py +337 -0
- braindecode/models/fbcnet.py +225 -0
- braindecode/models/fblightconvnet.py +315 -0
- braindecode/models/fbmsnet.py +338 -0
- braindecode/models/hybrid.py +126 -0
- braindecode/models/ifnet.py +443 -0
- braindecode/models/labram.py +1316 -0
- braindecode/models/luna.py +891 -0
- braindecode/models/medformer.py +760 -0
- braindecode/models/msvtnet.py +377 -0
- braindecode/models/patchedtransformer.py +640 -0
- braindecode/models/reve.py +843 -0
- braindecode/models/sccnet.py +280 -0
- braindecode/models/shallow_fbcsp.py +212 -0
- braindecode/models/signal_jepa.py +1122 -0
- braindecode/models/sinc_shallow.py +339 -0
- braindecode/models/sleep_stager_blanco_2020.py +169 -0
- braindecode/models/sleep_stager_chambon_2018.py +159 -0
- braindecode/models/sparcnet.py +426 -0
- braindecode/models/sstdpn.py +869 -0
- braindecode/models/summary.csv +47 -0
- braindecode/models/syncnet.py +234 -0
- braindecode/models/tcn.py +275 -0
- braindecode/models/tidnet.py +397 -0
- braindecode/models/tsinception.py +295 -0
- braindecode/models/usleep.py +439 -0
- braindecode/models/util.py +369 -0
- braindecode/modules/__init__.py +92 -0
- braindecode/modules/activation.py +86 -0
- braindecode/modules/attention.py +883 -0
- braindecode/modules/blocks.py +160 -0
- braindecode/modules/convolution.py +330 -0
- braindecode/modules/filter.py +654 -0
- braindecode/modules/layers.py +216 -0
- braindecode/modules/linear.py +70 -0
- braindecode/modules/parametrization.py +38 -0
- braindecode/modules/stats.py +87 -0
- braindecode/modules/util.py +85 -0
- braindecode/modules/wrapper.py +90 -0
- braindecode/preprocessing/__init__.py +271 -0
- braindecode/preprocessing/eegprep_preprocess.py +1317 -0
- braindecode/preprocessing/mne_preprocess.py +240 -0
- braindecode/preprocessing/preprocess.py +579 -0
- braindecode/preprocessing/util.py +177 -0
- braindecode/preprocessing/windowers.py +1037 -0
- braindecode/regressor.py +234 -0
- braindecode/samplers/__init__.py +18 -0
- braindecode/samplers/base.py +399 -0
- braindecode/samplers/ssl.py +263 -0
- braindecode/training/__init__.py +23 -0
- braindecode/training/callbacks.py +23 -0
- braindecode/training/losses.py +105 -0
- braindecode/training/scoring.py +477 -0
- braindecode/util.py +419 -0
- braindecode/version.py +1 -0
- braindecode/visualization/__init__.py +8 -0
- braindecode/visualization/confusion_matrices.py +289 -0
- braindecode/visualization/gradients.py +62 -0
- braindecode-1.3.0.dev177069446.dist-info/METADATA +230 -0
- braindecode-1.3.0.dev177069446.dist-info/RECORD +124 -0
- braindecode-1.3.0.dev177069446.dist-info/WHEEL +5 -0
- braindecode-1.3.0.dev177069446.dist-info/licenses/LICENSE.txt +31 -0
- braindecode-1.3.0.dev177069446.dist-info/licenses/NOTICE.txt +20 -0
- braindecode-1.3.0.dev177069446.dist-info/top_level.txt +1 -0
@@ -0,0 +1,987 @@
# mypy: ignore-errors
"""
Hugging Face Hub integration for EEG datasets.

This module provides push_to_hub() and pull_from_hub() functionality
for braindecode datasets, similar to the model Hub integration.

.. warning::
    The format is **BIDS-inspired**, not **BIDS-compliant**. The metadata
    files are BIDS-compliant, but the data is stored in Zarr format for
    efficient training, which is not a valid BIDS EEG format.

The format follows a BIDS-inspired sourcedata structure:
- sourcedata/braindecode/
    - dataset_description.json (BIDS-compliant)
    - participants.tsv (BIDS-compliant)
    - dataset.zarr/ (NOT BIDS-compliant - efficient data store)
    - sub-<label>/
        - eeg/
            - *_events.tsv (BIDS-compliant)
            - *_channels.tsv (BIDS-compliant)
            - *_eeg.json (BIDS-compliant)
"""

# Authors: Kuntal Kokate
#          Bruno Aristimunha <b.aristimunha@gmail.com>
#
# License: BSD (3-clause)

import json
import logging
import tempfile
from datetime import timedelta
from pathlib import Path
from typing import TYPE_CHECKING, List, Optional, Union

import mne
import numpy as np
import pandas as pd
import scipy
from mne._fiff.meas_info import Info
from mne.utils import _soft_import

if TYPE_CHECKING:
    from ..base import BaseDataset

import braindecode

# Import registry for dynamic class lookup (avoids circular imports)
from ..registry import get_dataset_class, get_dataset_type

# Hub format and validation utilities
from . import hub_format, hub_validation
from .hub_io import (
    _create_compressor,
    _load_eegwindows_from_zarr,
    _load_raw_from_zarr,
    _load_windows_from_zarr,
    _save_eegwindows_to_zarr,
    _save_raw_to_zarr,
    _save_windows_to_zarr,
)

# Lazy import zarr and huggingface_hub
zarr = _soft_import("zarr", purpose="hugging face integration", strict=False)
huggingface_hub = _soft_import(
    "huggingface_hub", purpose="hugging face integration", strict=False
)

log = logging.getLogger(__name__)


class HubDatasetMixin:
    """
    Mixin class for Hugging Face Hub integration with EEG datasets.

    This class adds `push_to_hub()` and `pull_from_hub()` methods to
    BaseConcatDataset, enabling easy upload and download of datasets
    to/from the Hugging Face Hub.

    Examples
    --------
    >>> # Push dataset to Hub
    >>> dataset = NMT(path=path, preload=True)
    >>> dataset.push_to_hub(
    ...     repo_id="username/nmt-dataset",
    ...     commit_message="Add NMT dataset"
    ... )
    >>>
    >>> # Load dataset from Hub
    >>> dataset = BaseConcatDataset.pull_from_hub("username/nmt-dataset")
    """

    datasets: List["BaseDataset"]  # Attribute provided by inheriting class

    def push_to_hub(
        self,
        repo_id: str,
        commit_message: Optional[str] = None,
        private: bool = False,
        token: Optional[str] = None,
        create_pr: bool = False,
        compression: str = "blosc",
        compression_level: int = 5,
        pipeline_name: str = "braindecode",
    ) -> str:
        """
        Upload the dataset to the Hugging Face Hub in BIDS-like Zarr format.

        The dataset is converted to Zarr format with blosc compression, which provides
        optimal random access performance for PyTorch training. The data is stored
        in a BIDS sourcedata-like structure with events.tsv, channels.tsv,
        and participants.tsv sidecar files.

        Parameters
        ----------
        repo_id : str
            Repository ID on the Hugging Face Hub (e.g., "username/dataset-name").
        commit_message : str | None
            Commit message. If None, a default message is generated.
        private : bool, default=False
            Whether to create a private repository.
        token : str | None
            Hugging Face API token. If None, uses cached token.
        create_pr : bool, default=False
            Whether to create a Pull Request instead of directly committing.
        compression : str, default="blosc"
            Compression algorithm for Zarr. Options: "blosc", "zstd", "gzip", None.
        compression_level : int, default=5
            Compression level (0-9). Level 5 provides optimal balance.
        pipeline_name : str, default="braindecode"
            Name of the processing pipeline for BIDS sourcedata.

        Returns
        -------
        str
            URL of the uploaded dataset on the Hub.

        Raises
        ------
        ImportError
            If huggingface-hub is not installed.
        ValueError
            If the dataset is empty or format is invalid.

        Examples
        --------
        >>> dataset = NMT(path=path, preload=True)
        >>> # Upload with BIDS-like structure
        >>> url = dataset.push_to_hub(
        ...     repo_id="myusername/nmt-dataset",
        ...     commit_message="Upload NMT EEG dataset"
        ... )
        """
        if huggingface_hub is False or zarr is False:
            raise ImportError(
                "huggingface-hub or zarr is not installed. Install with: "
                "pip install braindecode[hub]"
            )

        # Create API instance
        _ = huggingface_hub.HfApi(token=token)

        # Create repository if it doesn't exist
        try:
            huggingface_hub.create_repo(
                repo_id=repo_id,
                token=token,
                private=private,
                repo_type="dataset",
                exist_ok=True,
            )
        except Exception as e:
            raise RuntimeError(f"Failed to create repository: {e}")

        # Create a temporary directory for upload
        with tempfile.TemporaryDirectory() as tmpdir:
            tmp_path = Path(tmpdir)

            # Create BIDS-like sourcedata structure
            log.info("Creating BIDS-like sourcedata structure...")
            bids_layout = hub_format.BIDSSourcedataLayout(
                tmp_path, pipeline_name=pipeline_name
            )
            sourcedata_dir = bids_layout.create_structure()

            # Save dataset_description.json
            bids_layout.save_dataset_description()

            # Save participants.tsv
            descriptions = [ds.description for ds in self.datasets]
            bids_layout.save_participants(descriptions)

            # Save BIDS sidecar files for each recording
            self._save_bids_sidecar_files(bids_layout)

            # Convert dataset to Zarr format inside sourcedata
            log.info("Converting dataset to Zarr format...")
            dataset_path = sourcedata_dir / "dataset.zarr"

            self._convert_to_zarr_inline(
                dataset_path,
                compression,
                compression_level,
            )

            # Save dataset metadata (README.md)
            self._save_dataset_card(tmp_path)

            # Save format info
            format_info_path = tmp_path / "format_info.json"
            with open(format_info_path, "w", encoding="utf-8") as f:
                format_info = self._get_format_info_inline()
                json.dump(
                    {
                        "format": "zarr",
                        "pipeline_name": pipeline_name,
                        "compression": compression,
                        "compression_level": compression_level,
                        "braindecode_version": braindecode.__version__,
                        **format_info,
                    },
                    f,
                    indent=2,
                )

            # Default commit message
            if commit_message is None:
                commit_message = (
                    f"Upload EEG dataset in BIDS-like "
                    f"Zarr format ({len(self.datasets)} recordings)"
                )

            # Upload folder to Hub
            log.info(f"Uploading to Hugging Face Hub ({repo_id})...")
            try:
                url = huggingface_hub.upload_folder(
                    repo_id=repo_id,
                    folder_path=str(tmp_path),
                    repo_type="dataset",
                    commit_message=commit_message,
                    token=token,
                    create_pr=create_pr,
                )
                log.info(f"Dataset uploaded successfully to {repo_id}")
                log.info(f"URL: https://huggingface.co/datasets/{repo_id}")
                return url
            except Exception as e:
                raise RuntimeError(f"Failed to upload dataset: {e}")
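For orientation, a sketch of what the `format_info.json` written by `push_to_hub()` ends up containing: the first five keys come from the `json.dump` call above, the last three from `_get_format_info_inline()`. All values below are invented examples, not taken from a real upload.

```python
# Illustrative contents of format_info.json (values are made up).
format_info = {
    "format": "zarr",
    "pipeline_name": "braindecode",
    "compression": "blosc",
    "compression_level": 5,
    "braindecode_version": "1.3.0",  # whatever braindecode.__version__ reports
    "n_recordings": 2,               # summary added by _get_format_info_inline()
    "total_samples": 1200,
    "total_size_mb": 153.42,
}
```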
    def _save_dataset_card(self, path: Path, bids_inspired: bool = True) -> None:
        """Generate and save a dataset card (README.md) with metadata.

        Parameters
        ----------
        path : Path
            Directory where README.md will be saved.
        bids_inspired : bool
            Whether to include BIDS-inspired format documentation.
        """
        # Get info, which also validates uniformity across all datasets
        format_info = self._get_format_info_inline()

        n_recordings = len(self.datasets)
        first_ds = self.datasets[0]

        # Get dataset-specific info based on type using registry
        dataset_type = get_dataset_type(first_ds)

        n_windows = format_info["total_samples"]

        # Compute total duration across all recordings
        total_duration = 0.0
        if dataset_type == "WindowsDataset":
            n_channels = len(first_ds.windows.ch_names)
            data_type = "Windowed (from Epochs object)"
            sfreq = first_ds.windows.info["sfreq"]
            for ds in self.datasets:
                epoch_length = ds.windows.tmax - ds.windows.tmin
                total_duration += len(ds.windows) * epoch_length
        elif dataset_type == "EEGWindowsDataset":
            n_channels = len(first_ds.raw.ch_names)
            sfreq = first_ds.raw.info["sfreq"]
            data_type = "Windowed (from Raw object)"
            for ds in self.datasets:
                total_duration += ds.raw.n_times / ds.raw.info["sfreq"]
        elif dataset_type == "RawDataset":
            n_channels = len(first_ds.raw.ch_names)
            sfreq = first_ds.raw.info["sfreq"]
            data_type = "Continuous (Raw)"
            for ds in self.datasets:
                total_duration += ds.raw.n_times / ds.raw.info["sfreq"]
        else:
            raise TypeError(f"Unsupported dataset type: {dataset_type}")

        # Create README content and save
        readme_content = _generate_readme_content(
            format_info=format_info,
            n_recordings=n_recordings,
            n_channels=n_channels,
            sfreq=sfreq,
            data_type=data_type,
            n_windows=n_windows,
            total_duration=total_duration,
        )

        # Save README
        readme_path = path / "README.md"
        with open(readme_path, "w", encoding="utf-8") as f:
            f.write(readme_content)

    def _save_bids_sidecar_files(
        self, bids_layout: "hub_format.BIDSSourcedataLayout"
    ) -> None:
        """Save BIDS-compliant sidecar files for each recording.

        This creates events.tsv, channels.tsv, and EEG sidecar JSON files
        for each recording in a BIDS-like directory structure.

        Parameters
        ----------
        bids_layout : BIDSSourcedataLayout
            BIDS layout object for path generation.
        """
        dataset_type = get_dataset_type(self.datasets[0])

        for i_ds, ds in enumerate(self.datasets):
            # Get BIDS entities from description
            description = ds.description if ds.description is not None else pd.Series()

            # Get BIDSPath for this recording using mne_bids
            bids_path = bids_layout.get_bids_path(description)

            # Create subject directory
            bids_path.mkdir(exist_ok=True)

            # Get metadata and info based on dataset type
            # Also compute recording_duration, recording_type, and epoch_length
            recording_duration = None
            recording_type = None
            epoch_length = None

            if dataset_type == "WindowsDataset":
                info = ds.windows.info
                metadata = ds.windows.metadata
                sfreq = info["sfreq"]
                # WindowsDataset contains pre-cut epochs
                recording_type = "epoched"
                # Use MNE's tmax - tmin for epoch length
                epoch_length = ds.windows.tmax - ds.windows.tmin
                # Total duration = number of epochs * epoch length
                n_epochs = len(ds.windows)
                recording_duration = n_epochs * epoch_length
            elif dataset_type == "EEGWindowsDataset":
                info = ds.raw.info
                metadata = ds.metadata
                sfreq = info["sfreq"]
                # EEGWindowsDataset has continuous raw with window metadata
                recording_type = "epoched"
                # Use MNE Raw's duration property
                recording_duration = ds.raw.duration
                # Compute epoch_length from metadata if available
                if metadata is not None and len(metadata) > 0:
                    i_start = metadata["i_start_in_trial"].iloc[0]
                    i_stop = metadata["i_stop_in_trial"].iloc[0]
                    epoch_length = (i_stop - i_start) / sfreq
            elif dataset_type == "RawDataset":
                info = ds.raw.info
                metadata = None
                sfreq = info["sfreq"]
                # RawDataset is continuous
                recording_type = "continuous"
                # Use MNE Raw's duration property
                recording_duration = ds.raw.duration
            else:
                continue

            # Determine task name from description or BIDSPath
            task_name = bids_path.task or "unknown"

            # Save BIDS sidecar files using mne_bids BIDSPath
            hub_format.save_bids_sidecar_files(
                bids_path=bids_path,
                info=info,
                metadata=metadata,
                sfreq=sfreq,
                task_name=str(task_name),
                recording_duration=recording_duration,
                recording_type=recording_type,
                epoch_length=epoch_length,
            )

            log.debug(
                f"Saved BIDS sidecar files for recording {i_ds} to {bids_path.directory}"
            )

    @classmethod
    def pull_from_hub(
        cls,
        repo_id: str,
        preload: bool = True,
        token: Optional[str] = None,
        cache_dir: Optional[Union[str, Path]] = None,
        force_download: bool = False,
        **kwargs,
    ):
        """
        Load a dataset from the Hugging Face Hub.

        Parameters
        ----------
        repo_id : str
            Repository ID on the Hugging Face Hub (e.g., "username/dataset-name").
        preload : bool, default=True
            Whether to preload the data into memory. If False, uses lazy loading
            (when supported by the format).
        token : str | None
            Hugging Face API token. If None, uses cached token.
        cache_dir : str | Path | None
            Directory to cache the downloaded dataset. If None, uses default
            cache directory (~/.cache/huggingface/datasets).
        force_download : bool, default=False
            Whether to force re-download even if cached.
        **kwargs
            Additional arguments (currently unused).

        Returns
        -------
        BaseConcatDataset
            The loaded dataset.

        Raises
        ------
        ImportError
            If huggingface-hub is not installed.
        FileNotFoundError
            If the repository or dataset files are not found.

        Examples
        --------
        >>> from braindecode.datasets import BaseConcatDataset
        >>> dataset = BaseConcatDataset.pull_from_hub("username/nmt-dataset")
        >>> print(f"Loaded {len(dataset)} windows")
        >>>
        >>> # Use with PyTorch
        >>> from torch.utils.data import DataLoader
        >>> loader = DataLoader(dataset, batch_size=32, shuffle=True)
        """
        if zarr is False or huggingface_hub is False:
            raise ImportError(
                "huggingface hub functionality is not installed. Install with: "
                "pip install braindecode[hub]"
            )

        log.info(f"Loading dataset from Hugging Face Hub ({repo_id})...")

        try:
            # Download the entire dataset directory
            dataset_dir = huggingface_hub.snapshot_download(
                repo_id=repo_id,
                repo_type="dataset",
                token=token,
                cache_dir=cache_dir,
                force_download=force_download,
            )

            # Load format info
            format_info_path = Path(dataset_dir) / "format_info.json"
            if format_info_path.exists():
                with open(format_info_path, "r") as f:
                    format_info = json.load(f)

                # Verify it's zarr format
                if format_info.get("format") != "zarr":
                    raise ValueError(
                        f"Dataset format is '{format_info.get('format')}', but only "
                        "'zarr' format is supported. Please re-upload the dataset."
                    )
            else:
                format_info = {}

            pipeline_name = format_info.get("pipeline_name", "braindecode")

            # Find zarr dataset path (try sourcedata, derivatives, then root)
            zarr_path = (
                Path(dataset_dir) / "sourcedata" / pipeline_name / "dataset.zarr"
            )
            if not zarr_path.exists():
                zarr_path = (
                    Path(dataset_dir) / "derivatives" / pipeline_name / "dataset.zarr"
                )
            if not zarr_path.exists():
                zarr_path = Path(dataset_dir) / "dataset.zarr"

            if not zarr_path.exists():
                raise FileNotFoundError(
                    f"Zarr dataset not found at {zarr_path}. "
                    "The dataset may be in an unsupported format."
                )

            dataset = cls._load_from_zarr_inline(zarr_path, preload)

            # Load BIDS metadata if available
            cls._load_bids_metadata(dataset, Path(dataset_dir), pipeline_name)

            log.info(f"Dataset loaded successfully from {repo_id}")
            log.info(f"Recordings: {len(dataset.datasets)}")
            log.info(
                f"Total windows/samples: {format_info.get('total_samples', 'N/A')}"
            )

            return dataset

        except huggingface_hub.utils.HfHubHTTPError as e:
            if e.response.status_code == 404:
                raise FileNotFoundError(
                    f"Dataset '{repo_id}' not found on Hugging Face Hub. "
                    "Please check the repository ID and ensure it exists."
                )
            else:
                raise RuntimeError(f"Failed to download dataset: {e}")
        except Exception as e:
            raise RuntimeError(f"Failed to load dataset from Hub: {e}")
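A minimal usage sketch for `pull_from_hub()`, combining the download options with the BIDS metadata that `_load_bids_metadata()` attaches afterwards. The repo id and cache directory are placeholders; the `participants` / `bids_events` / `bids_channels` attributes are only present when the corresponding TSV files exist in the repository.

```python
from braindecode.datasets import BaseConcatDataset

# Placeholder repo id; cache_dir is optional and defaults to the HF cache.
dataset = BaseConcatDataset.pull_from_hub(
    "username/dataset-name",
    preload=True,
    cache_dir="./hf_cache",
)

# BIDS sidecar metadata attached by _load_bids_metadata(), if available.
if hasattr(dataset, "participants"):
    print(dataset.participants.head())          # participants.tsv as a DataFrame
if hasattr(dataset.datasets[0], "bids_events"):
    print(dataset.datasets[0].bids_events.head())  # *_events.tsv for recording 0
```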
    @classmethod
    def _load_bids_metadata(
        cls,
        dataset,
        dataset_dir: Path,
        pipeline_name: str,
    ) -> None:
        """Load BIDS metadata from sidecar files and attach to dataset.

        Parameters
        ----------
        dataset : BaseConcatDataset
            The loaded dataset to attach metadata to.
        dataset_dir : Path
            Root directory of the downloaded dataset.
        pipeline_name : str
            Name of the processing pipeline.
        """
        # Try sourcedata first, fall back to derivatives for backwards compatibility
        sourcedata_dir = dataset_dir / "sourcedata" / pipeline_name
        if not sourcedata_dir.exists():
            sourcedata_dir = dataset_dir / "derivatives" / pipeline_name

        # Load participants.tsv if available
        participants_path = sourcedata_dir / "participants.tsv"
        if participants_path.exists():
            try:
                participants_df = pd.read_csv(participants_path, sep="\t")
                # Store as attribute on the concat dataset
                dataset.participants = participants_df
                log.debug(
                    f"Loaded participants info for {len(participants_df)} subjects"
                )
            except Exception as e:
                log.warning(f"Failed to load participants.tsv: {e}")

        # Create layout for path generation
        bids_layout = hub_format.BIDSSourcedataLayout(
            dataset_dir, pipeline_name=pipeline_name
        )

        # Try to load events.tsv files and attach to individual datasets
        for i_ds, ds in enumerate(dataset.datasets):
            description = ds.description if ds.description is not None else pd.Series()

            # Get BIDSPath for this recording
            bids_path = bids_layout.get_bids_path(description)

            # Load events.tsv if available
            events_path = bids_path.copy().update(suffix="events", extension=".tsv")
            if events_path.fpath.exists():
                try:
                    events_df = pd.read_csv(events_path.fpath, sep="\t")
                    ds.bids_events = events_df
                    log.debug(f"Loaded events for recording {i_ds}")
                except Exception as e:
                    log.warning(f"Failed to load events for recording {i_ds}: {e}")

            # Load channels.tsv if available
            channels_path = bids_path.copy().update(suffix="channels", extension=".tsv")
            if channels_path.fpath.exists():
                try:
                    channels_df = pd.read_csv(channels_path.fpath, sep="\t")
                    ds.bids_channels = channels_df
                    log.debug(f"Loaded channels for recording {i_ds}")
                except Exception as e:
                    log.warning(f"Failed to load channels for recording {i_ds}: {e}")

    def _convert_to_zarr_inline(
        self,
        output_path: Path,
        compression: str,
        compression_level: int,
    ) -> None:
        """Convert dataset to Zarr format (inline implementation)."""

        if zarr is False or huggingface_hub is False:
            raise ImportError(
                "huggingface hub functionality is not installed. Install with: "
                "pip install braindecode[hub]"
            )

        if output_path.exists():
            raise FileExistsError(
                f"{output_path} already exists. Set overwrite=True to replace it."
            )

        # Create zarr store (zarr v3 API)
        root = zarr.open(str(output_path), mode="w")

        # Validate uniformity across all datasets using shared validation
        dataset_type, _, _ = hub_validation.validate_dataset_uniformity(self.datasets)

        # Keep reference to first dataset for preprocessing kwargs
        first_ds = self.datasets[0]

        # Store global metadata
        root.attrs["n_datasets"] = len(self.datasets)
        root.attrs["dataset_type"] = dataset_type
        root.attrs["braindecode_version"] = braindecode.__version__

        # Track dependency versions for reproducibility
        root.attrs["mne_version"] = mne.__version__
        root.attrs["numpy_version"] = np.__version__
        root.attrs["pandas_version"] = pd.__version__
        root.attrs["zarr_version"] = zarr.__version__
        root.attrs["scipy_version"] = scipy.__version__

        # Save preprocessing kwargs (check first dataset, assuming uniform preprocessing)
        # These are typically set by windowing functions on individual datasets
        for kwarg_name in [
            "raw_preproc_kwargs",
            "window_kwargs",
            "window_preproc_kwargs",
        ]:
            # Check first dataset for these attributes
            if hasattr(first_ds, kwarg_name):
                kwargs = getattr(first_ds, kwarg_name)
                if kwargs:
                    root.attrs[kwarg_name] = json.dumps(kwargs)

        # Create compressor
        compressor = _create_compressor(compression, compression_level)

        # Save each recording
        for i_ds, ds in enumerate(self.datasets):
            grp = root.create_group(f"recording_{i_ds}")

            if dataset_type == "WindowsDataset":
                # Extract data from WindowsDataset
                data = ds.windows.get_data()
                metadata = ds.windows.metadata
                description = ds.description
                info_dict = ds.windows.info.to_json_dict()
                target_name = ds.target_name if hasattr(ds, "target_name") else None

                # Save using inlined function
                _save_windows_to_zarr(
                    grp, data, metadata, description, info_dict, compressor, target_name
                )

            elif dataset_type == "EEGWindowsDataset":
                # Get continuous raw data and metadata from EEGWindowsDataset
                raw = ds.raw
                metadata = ds.metadata
                description = ds.description
                info_dict = ds.raw.info.to_json_dict()
                targets_from = ds.targets_from
                last_target_only = ds.last_target_only

                # Save using inlined function (saves continuous raw directly)
                _save_eegwindows_to_zarr(
                    grp,
                    raw,
                    metadata,
                    description,
                    info_dict,
                    targets_from,
                    last_target_only,
                    compressor,
                )

            elif dataset_type == "RawDataset":
                # Get continuous raw data from RawDataset
                raw = ds.raw
                description = ds.description
                info_dict = ds.raw.info.to_json_dict()
                target_name = ds.target_name if hasattr(ds, "target_name") else None

                # Save using inlined function
                _save_raw_to_zarr(
                    grp, raw, description, info_dict, target_name, compressor
                )
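To make the on-disk layout produced by `_convert_to_zarr_inline()` concrete: the store holds one `recording_<i>` group per dataset plus the global and version attributes set above. A minimal inspection sketch, assuming a store already written to a placeholder path `dataset.zarr` (the contents of each group come from the `hub_io` helpers and are not shown in this file):

```python
import zarr

# Open an existing store read-only and look at the metadata written above.
root = zarr.open("dataset.zarr", mode="r")
print(root.attrs["dataset_type"], root.attrs["n_datasets"])
print(root.attrs["braindecode_version"], root.attrs["zarr_version"])
print(sorted(root.keys()))  # e.g. ['recording_0', 'recording_1', ...]
```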
    def _get_format_info_inline(self):
        """Get format info (inline implementation).

        This is an inline version of hub_formats.get_format_info() that avoids
        circular import.
        """
        if len(self.datasets) == 0:
            raise ValueError("Cannot get format info for empty dataset")

        # Validate uniformity across all datasets using shared validation
        dataset_type, _, _ = hub_validation.validate_dataset_uniformity(self.datasets)

        # Calculate dataset size
        # BaseConcatDataset's __len__ already sums len(ds) for all datasets
        total_samples = len(self)
        total_size_mb = 0

        for ds in self.datasets:
            if dataset_type == "WindowsDataset":
                # Use MNE's internal _size property to avoid loading data
                total_size_mb += ds.windows._size / (1024 * 1024)
            elif dataset_type == "EEGWindowsDataset":
                # Use raw object's size (not extracted windows)
                total_size_mb += ds.raw._size / (1024 * 1024)
            elif dataset_type == "RawDataset":
                total_size_mb += ds.raw._size / (1024 * 1024)

        n_recordings = len(self.datasets)

        return {
            "n_recordings": n_recordings,
            "total_samples": total_samples,
            "total_size_mb": round(total_size_mb, 2),
        }

    @staticmethod
    def _load_from_zarr_inline(input_path: Path, preload: bool):
        """Load dataset from Zarr format (inline implementation).

        This is an inline version of hub_formats.load_from_zarr() that avoids
        circular import by using hub_formats_core directly.
        """
        if not input_path.exists():
            raise FileNotFoundError(f"{input_path} does not exist.")

        # Open zarr store (zarr v3 API)
        root = zarr.open(str(input_path), mode="r")

        n_datasets = root.attrs["n_datasets"]
        dataset_type = root.attrs["dataset_type"]

        # Get dataset classes from registry
        WindowsDataset = get_dataset_class("WindowsDataset")
        EEGWindowsDataset = get_dataset_class("EEGWindowsDataset")
        RawDataset = get_dataset_class("RawDataset")
        BaseConcatDataset = get_dataset_class("BaseConcatDataset")

        datasets = []
        for i_ds in range(n_datasets):
            grp = root[f"recording_{i_ds}"]

            if dataset_type == "WindowsDataset":
                # Load using inlined function
                data, metadata, description, info_dict, target_name = (
                    _load_windows_from_zarr(grp, preload)
                )

                # Convert to MNE objects and create dataset
                info = Info.from_json_dict(info_dict)
                events = np.column_stack(
                    [
                        metadata["i_start_in_trial"].values,
                        np.zeros(len(metadata), dtype=int),
                        metadata["target"].values,
                    ]
                )
                epochs = mne.EpochsArray(data, info, events=events, metadata=metadata)
                ds = WindowsDataset(epochs, description)
                if target_name is not None:
                    ds.target_name = target_name

            elif dataset_type == "EEGWindowsDataset":
                # Load using inlined function
                (
                    data,
                    metadata,
                    description,
                    info_dict,
                    targets_from,
                    last_target_only,
                ) = _load_eegwindows_from_zarr(grp, preload)

                # Convert to MNE objects and create dataset
                # Data is already in continuous format [n_channels, n_timepoints]
                info = Info.from_json_dict(info_dict)
                raw = mne.io.RawArray(data, info)
                ds = EEGWindowsDataset(
                    raw=raw,
                    metadata=metadata,
                    description=description,
                    targets_from=targets_from,
                    last_target_only=last_target_only,
                )

            elif dataset_type == "RawDataset":
                # Load using inlined function
                data, description, info_dict, target_name = _load_raw_from_zarr(
                    grp, preload
                )

                # Convert to MNE objects and create dataset
                # Data is in continuous format [n_channels, n_timepoints]
                info = Info.from_json_dict(info_dict)
                raw = mne.io.RawArray(data, info)
                ds = RawDataset(raw, description)
                if target_name is not None:
                    ds.target_name = target_name

            else:
                raise ValueError(f"Unsupported dataset_type: {dataset_type}")

            datasets.append(ds)

        # Create concat dataset
        concat_ds = BaseConcatDataset(datasets)

        # Restore preprocessing kwargs (set on individual datasets, not concat)
        for kwarg_name in [
            "raw_preproc_kwargs",
            "window_kwargs",
            "window_preproc_kwargs",
        ]:
            if kwarg_name in root.attrs:
                kwargs = json.loads(root.attrs[kwarg_name])
                # Set on each individual dataset (where they were originally stored)
                for ds in datasets:
                    setattr(ds, kwarg_name, kwargs)

        return concat_ds


def _generate_readme_content(
    format_info,
    n_recordings: int,
    n_channels: int,
    sfreq,
    data_type: str,
    n_windows: int,
    total_duration: float | None = None,
    format: str = "zarr",
):
    """Generate README.md content for a dataset uploaded to the Hub.

    Parameters
    ----------
    format_info : dict
        Dictionary containing format metadata (e.g., total_size_mb).
    n_recordings : int
        Number of recordings in the dataset.
    n_channels : int
        Number of EEG channels.
    sfreq : float or None
        Sampling frequency in Hz.
    data_type : str
        Type of dataset (e.g., "Windowed", "Continuous").
    n_windows : int
        Number of windows/samples in the dataset.
    total_duration : float or None
        Total duration in seconds across all recordings.
    format : str
        Storage format (default: "zarr").

    Returns
    -------
    str
        Markdown content for the README.md file.
    """
    total_size_mb = (
        format_info.get("total_size_mb", 0.0) if isinstance(format_info, dict) else 0.0
    )
    sfreq_str = f"{sfreq:g}" if sfreq is not None else "N/A"

    duration_str = (
        str(timedelta(seconds=int(total_duration))) if total_duration else "N/A"
    )

    return f"""---
tags:
- braindecode
- eeg
- neuroscience
- brain-computer-interface
- deep-learning
license: unknown
---

# EEG Dataset

This dataset was created using [braindecode](https://braindecode.org), a deep
learning library for EEG/MEG/ECoG signals.

## Dataset Information

| Property | Value |
|----------|------:|
| Recordings | {n_recordings} |
| Type | {data_type} |
| Channels | {n_channels} |
| Sampling frequency | {sfreq_str} Hz |
| Total duration | {duration_str} |
| Windows/samples | {n_windows:,} |
| Size | {total_size_mb:.2f} MB |
| Format | {format} |

## Quick Start

```python
from braindecode.datasets import BaseConcatDataset

# Load from Hugging Face Hub
dataset = BaseConcatDataset.pull_from_hub("username/dataset-name")

# Access a sample
X, y, metainfo = dataset[0]
# X: EEG data [n_channels, n_times]
# y: target label
# metainfo: window indices
```

## Training with PyTorch

```python
from torch.utils.data import DataLoader

loader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4)

for X, y, metainfo in loader:
    # X: [batch_size, n_channels, n_times]
    # y: [batch_size]
    pass  # Your training code
```

## BIDS-inspired Structure

This dataset uses a **BIDS-inspired** organization. Metadata files follow BIDS
conventions, while data is stored in Zarr format for efficient deep learning.

**BIDS-style metadata:**
- `dataset_description.json` - Dataset information
- `participants.tsv` - Subject metadata
- `*_events.tsv` - Trial/window events
- `*_channels.tsv` - Channel information
- `*_eeg.json` - Recording parameters

**Data storage:**
- `dataset.zarr/` - Zarr format (optimized for random access)

```
sourcedata/braindecode/
├── dataset_description.json
├── participants.tsv
├── dataset.zarr/
└── sub-<label>/
    └── eeg/
        ├── *_events.tsv
        ├── *_channels.tsv
        └── *_eeg.json
```

### Accessing Metadata

```python
# Participants info
if hasattr(dataset, "participants"):
    print(dataset.participants)

# Events for a recording
if hasattr(dataset.datasets[0], "bids_events"):
    print(dataset.datasets[0].bids_events)

# Channel info
if hasattr(dataset.datasets[0], "bids_channels"):
    print(dataset.datasets[0].bids_channels)
```

---

*Created with [braindecode](https://braindecode.org)*
"""
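A self-contained round-trip sketch of the workflow this module enables: build a small windowed dataset, push it, and pull it back for training. The synthetic recording, the `username/my-eeg-windows` repo id, and the assumption that `push_to_hub()` is available on the windowed `BaseConcatDataset` (as the module docstring states) are all illustrative; pushing requires the `braindecode[hub]` extras and Hugging Face credentials.

```python
import numpy as np
import mne
from braindecode.datasets import BaseDataset, BaseConcatDataset
from braindecode.preprocessing import create_fixed_length_windows
from torch.utils.data import DataLoader

# Tiny synthetic recording so the sketch runs without any downloads.
info = mne.create_info(ch_names=["C3", "C4"], sfreq=100.0, ch_types="eeg")
raw = mne.io.RawArray(np.random.randn(2, 1000) * 1e-6, info)
concat_ds = BaseConcatDataset([BaseDataset(raw, description={"subject": 1})])

windows_ds = create_fixed_length_windows(
    concat_ds,
    window_size_samples=100,
    window_stride_samples=100,
    drop_last_window=True,
    preload=True,
)

# Placeholder repo id; needs a logged-in huggingface_hub session or a token.
url = windows_ds.push_to_hub(repo_id="username/my-eeg-windows", private=True)
print(url)

# Later, possibly on another machine: pull it back and train.
dataset = BaseConcatDataset.pull_from_hub("username/my-eeg-windows")
loader = DataLoader(dataset, batch_size=64, shuffle=True)
```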