braindecode 1.3.0.dev177069446__py3-none-any.whl → 1.3.0.dev177628147__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- braindecode/augmentation/__init__.py +3 -5
- braindecode/augmentation/base.py +5 -8
- braindecode/augmentation/functional.py +22 -25
- braindecode/augmentation/transforms.py +42 -51
- braindecode/classifier.py +16 -11
- braindecode/datasets/__init__.py +3 -5
- braindecode/datasets/base.py +13 -17
- braindecode/datasets/bbci.py +14 -13
- braindecode/datasets/bcicomp.py +5 -4
- braindecode/datasets/{bids/datasets.py → bids.py} +18 -12
- braindecode/datasets/{bids/iterable.py → experimental.py} +6 -8
- braindecode/datasets/{bids/hub.py → hub.py} +350 -375
- braindecode/datasets/{bids/hub_validation.py → hub_validation.py} +1 -2
- braindecode/datasets/mne.py +19 -19
- braindecode/datasets/moabb.py +10 -10
- braindecode/datasets/nmt.py +56 -58
- braindecode/datasets/sleep_physio_challe_18.py +5 -3
- braindecode/datasets/sleep_physionet.py +5 -5
- braindecode/datasets/tuh.py +18 -21
- braindecode/datasets/xy.py +9 -10
- braindecode/datautil/__init__.py +3 -3
- braindecode/datautil/serialization.py +20 -22
- braindecode/datautil/util.py +7 -120
- braindecode/eegneuralnet.py +52 -22
- braindecode/functional/functions.py +10 -7
- braindecode/functional/initialization.py +2 -3
- braindecode/models/__init__.py +3 -5
- braindecode/models/atcnet.py +39 -43
- braindecode/models/attentionbasenet.py +41 -37
- braindecode/models/attn_sleep.py +24 -26
- braindecode/models/base.py +6 -6
- braindecode/models/bendr.py +26 -50
- braindecode/models/biot.py +30 -61
- braindecode/models/contrawr.py +5 -5
- braindecode/models/ctnet.py +35 -35
- braindecode/models/deep4.py +5 -5
- braindecode/models/deepsleepnet.py +7 -7
- braindecode/models/eegconformer.py +26 -31
- braindecode/models/eeginception_erp.py +2 -2
- braindecode/models/eeginception_mi.py +6 -6
- braindecode/models/eegitnet.py +5 -5
- braindecode/models/eegminer.py +1 -1
- braindecode/models/eegnet.py +3 -3
- braindecode/models/eegnex.py +2 -2
- braindecode/models/eegsimpleconv.py +2 -2
- braindecode/models/eegsym.py +7 -7
- braindecode/models/eegtcnet.py +6 -6
- braindecode/models/fbcnet.py +2 -2
- braindecode/models/fblightconvnet.py +3 -3
- braindecode/models/fbmsnet.py +3 -3
- braindecode/models/hybrid.py +2 -2
- braindecode/models/ifnet.py +5 -5
- braindecode/models/labram.py +46 -70
- braindecode/models/luna.py +5 -60
- braindecode/models/medformer.py +21 -23
- braindecode/models/msvtnet.py +15 -15
- braindecode/models/patchedtransformer.py +55 -55
- braindecode/models/sccnet.py +2 -2
- braindecode/models/shallow_fbcsp.py +3 -5
- braindecode/models/signal_jepa.py +12 -39
- braindecode/models/sinc_shallow.py +4 -3
- braindecode/models/sleep_stager_blanco_2020.py +2 -2
- braindecode/models/sleep_stager_chambon_2018.py +2 -2
- braindecode/models/sparcnet.py +8 -8
- braindecode/models/sstdpn.py +869 -869
- braindecode/models/summary.csv +17 -19
- braindecode/models/syncnet.py +2 -2
- braindecode/models/tcn.py +5 -5
- braindecode/models/tidnet.py +3 -3
- braindecode/models/tsinception.py +3 -3
- braindecode/models/usleep.py +7 -7
- braindecode/models/util.py +14 -165
- braindecode/modules/__init__.py +1 -9
- braindecode/modules/activation.py +3 -29
- braindecode/modules/attention.py +0 -123
- braindecode/modules/blocks.py +1 -53
- braindecode/modules/convolution.py +0 -53
- braindecode/modules/filter.py +0 -31
- braindecode/modules/layers.py +0 -84
- braindecode/modules/linear.py +1 -22
- braindecode/modules/stats.py +0 -10
- braindecode/modules/util.py +0 -9
- braindecode/modules/wrapper.py +0 -17
- braindecode/preprocessing/preprocess.py +0 -3
- braindecode/regressor.py +18 -15
- braindecode/samplers/ssl.py +1 -1
- braindecode/util.py +28 -38
- braindecode/version.py +1 -1
- braindecode-1.3.0.dev177628147.dist-info/METADATA +202 -0
- braindecode-1.3.0.dev177628147.dist-info/RECORD +114 -0
- braindecode/datasets/bids/__init__.py +0 -54
- braindecode/datasets/bids/format.py +0 -717
- braindecode/datasets/bids/hub_format.py +0 -717
- braindecode/datasets/bids/hub_io.py +0 -197
- braindecode/datasets/chb_mit.py +0 -163
- braindecode/datasets/siena.py +0 -162
- braindecode/datasets/utils.py +0 -67
- braindecode/models/brainmodule.py +0 -845
- braindecode/models/config.py +0 -233
- braindecode/models/reve.py +0 -843
- braindecode-1.3.0.dev177069446.dist-info/METADATA +0 -230
- braindecode-1.3.0.dev177069446.dist-info/RECORD +0 -124
- {braindecode-1.3.0.dev177069446.dist-info → braindecode-1.3.0.dev177628147.dist-info}/WHEEL +0 -0
- {braindecode-1.3.0.dev177069446.dist-info → braindecode-1.3.0.dev177628147.dist-info}/licenses/LICENSE.txt +0 -0
- {braindecode-1.3.0.dev177069446.dist-info → braindecode-1.3.0.dev177628147.dist-info}/licenses/NOTICE.txt +0 -0
- {braindecode-1.3.0.dev177069446.dist-info → braindecode-1.3.0.dev177628147.dist-info}/top_level.txt +0 -0
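The bulk of this release is a rewrite of the dataset Hub integration: braindecode/datasets/bids/hub.py moves to braindecode/datasets/hub.py, the BIDS-inspired sourcedata layout and sidecar handling are removed, and the low-level Zarr I/O helpers from hub_io.py are inlined into the module. The reconstructed diff of hub.py below shows the new push/pull flow. As a rough, non-authoritative sketch of the round trip it exposes (the repository id and local path are hypothetical, the optional hub dependencies -- pip install braindecode[hub] -- are assumed installed, and NMT/push_to_hub are used as shown in the docstring examples in the diff):

# Sketch only: round trip suggested by the hub.py diff below.
# Assumes `pip install braindecode[hub]` (zarr, numcodecs, huggingface_hub)
# and that NMT is importable from braindecode.datasets; repo id is hypothetical.
from braindecode.datasets import BaseConcatDataset, NMT

dataset = NMT(path="path/to/nmt", preload=True)  # any concat dataset with the Hub mixin

url = dataset.push_to_hub(
    repo_id="myusername/nmt-dataset",            # hypothetical repository
    commit_message="Upload NMT EEG dataset",
    compression="blosc",                         # "blosc", "zstd", "gzip", or None
    compression_level=5,
)

# Later, restore the dataset directly from the Hub:
restored = BaseConcatDataset.pull_from_hub("myusername/nmt-dataset")
X, y, metainfo = restored[0]                     # (n_channels, n_times), target, window indices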
@@ -1,36 +1,19 @@
-# mypy: ignore-errors
 """
 Hugging Face Hub integration for EEG datasets.
 
 This module provides push_to_hub() and pull_from_hub() functionality
 for braindecode datasets, similar to the model Hub integration.
-
-.. warning::
-    The format is **BIDS-inspired**, not **BIDS-compliant**. The metadata
-    files are BIDS-compliant, but the data is stored in Zarr format for
-    efficient training, which is not a valid BIDS EEG format.
-
-The format follows a BIDS-inspired sourcedata structure:
-    - sourcedata/braindecode/
-        - dataset_description.json (BIDS-compliant)
-        - participants.tsv (BIDS-compliant)
-        - dataset.zarr/ (NOT BIDS-compliant - efficient data store)
-        - sub-<label>/
-            - eeg/
-                - *_events.tsv (BIDS-compliant)
-                - *_channels.tsv (BIDS-compliant)
-                - *_eeg.json (BIDS-compliant)
 """
 
 # Authors: Kuntal Kokate
-#          Bruno Aristimunha <b.aristimunha@gmail.com>
 #
 # License: BSD (3-clause)
 
+import io
 import json
 import logging
 import tempfile
-
+import warnings
 from pathlib import Path
 from typing import TYPE_CHECKING, List, Optional, Union
 
@@ -38,28 +21,28 @@ import mne
 import numpy as np
 import pandas as pd
 import scipy
-from mne._fiff.meas_info import Info
 from mne.utils import _soft_import
 
+# TODO: Simplify this logic in the future with zarr v3+ only
+# Optional imports for Hub functionality
+try:
+    from numcodecs import Blosc, GZip, Zstd
+
+    NUMCODECS_AVAILABLE = True
+except ImportError:
+    NUMCODECS_AVAILABLE = False
+    Blosc = GZip = Zstd = None
+
 if TYPE_CHECKING:
-    from
+    from .base import BaseDataset
 
 import braindecode
 
+# Import shared validation utilities
+from . import hub_validation
+
 # Import registry for dynamic class lookup (avoids circular imports)
-from
-
-# Hub format and validation utilities
-from . import hub_format, hub_validation
-from .hub_io import (
-    _create_compressor,
-    _load_eegwindows_from_zarr,
-    _load_raw_from_zarr,
-    _load_windows_from_zarr,
-    _save_eegwindows_to_zarr,
-    _save_raw_to_zarr,
-    _save_windows_to_zarr,
-)
+from .registry import get_dataset_class, get_dataset_type
 
 # Lazy import zarr and huggingface_hub
 zarr = _soft_import("zarr", purpose="hugging face integration", strict=False)
@@ -102,15 +85,13 @@ class HubDatasetMixin:
         create_pr: bool = False,
         compression: str = "blosc",
         compression_level: int = 5,
-        pipeline_name: str = "braindecode",
     ) -> str:
         """
-        Upload the dataset to the Hugging Face Hub in
+        Upload the dataset to the Hugging Face Hub in Zarr format.
 
         The dataset is converted to Zarr format with blosc compression, which provides
-        optimal random access performance for PyTorch training
-
-        and participants.tsv sidecar files.
+        optimal random access performance for PyTorch training (based on comprehensive
+        benchmarking).
 
         Parameters
         ----------
@@ -128,8 +109,6 @@ class HubDatasetMixin:
             Compression algorithm for Zarr. Options: "blosc", "zstd", "gzip", None.
         compression_level : int, default=5
             Compression level (0-9). Level 5 provides optimal balance.
-        pipeline_name : str, default="braindecode"
-            Name of the processing pipeline for BIDS sourcedata.
 
         Returns
         -------
@@ -146,11 +125,18 @@ class HubDatasetMixin:
         Examples
         --------
         >>> dataset = NMT(path=path, preload=True)
-        >>> # Upload with
+        >>> # Upload with default settings (zarr with blosc compression)
         >>> url = dataset.push_to_hub(
         ...     repo_id="myusername/nmt-dataset",
         ...     commit_message="Upload NMT EEG dataset"
         ... )
+        >>>
+        >>> # Or customize compression
+        >>> url = dataset.push_to_hub(
+        ...     repo_id="myusername/nmt-dataset",
+        ...     compression="blosc",
+        ...     compression_level=5
+        ... )
         """
         if huggingface_hub is False or zarr is False:
             raise ImportError(
@@ -177,44 +163,25 @@ class HubDatasetMixin:
         with tempfile.TemporaryDirectory() as tmpdir:
             tmp_path = Path(tmpdir)
 
-            #
-            log.info("Creating BIDS-like sourcedata structure...")
-            bids_layout = hub_format.BIDSSourcedataLayout(
-                tmp_path, pipeline_name=pipeline_name
-            )
-            sourcedata_dir = bids_layout.create_structure()
-
-            # Save dataset_description.json
-            bids_layout.save_dataset_description()
-
-            # Save participants.tsv
-            descriptions = [ds.description for ds in self.datasets]
-            bids_layout.save_participants(descriptions)
-
-            # Save BIDS sidecar files for each recording
-            self._save_bids_sidecar_files(bids_layout)
-
-            # Convert dataset to Zarr format inside sourcedata
+            # Convert dataset to Zarr format
             log.info("Converting dataset to Zarr format...")
-            dataset_path =
-
+            dataset_path = tmp_path / "dataset.zarr"
             self._convert_to_zarr_inline(
                 dataset_path,
                 compression,
                 compression_level,
             )
 
-            # Save dataset metadata
+            # Save dataset metadata
             self._save_dataset_card(tmp_path)
 
             # Save format info
             format_info_path = tmp_path / "format_info.json"
-            with open(format_info_path, "w"
+            with open(format_info_path, "w") as f:
                 format_info = self._get_format_info_inline()
                 json.dump(
                     {
                         "format": "zarr",
-                        "pipeline_name": pipeline_name,
                         "compression": compression,
                         "compression_level": compression_level,
                         "braindecode_version": braindecode.__version__,
@@ -227,8 +194,8 @@ class HubDatasetMixin:
             # Default commit message
             if commit_message is None:
                 commit_message = (
-                    f"Upload EEG dataset in
-                    f"
+                    f"Upload EEG dataset in Zarr format "
+                    f"({len(self.datasets)} recordings)"
                 )
 
             # Upload folder to Hub
@@ -248,15 +215,13 @@ class HubDatasetMixin:
         except Exception as e:
             raise RuntimeError(f"Failed to upload dataset: {e}")
 
-    def _save_dataset_card(self, path: Path
+    def _save_dataset_card(self, path: Path) -> None:
         """Generate and save a dataset card (README.md) with metadata.
 
         Parameters
         ----------
         path : Path
             Directory where README.md will be saved.
-        bids_inspired : bool
-            Whether to include BIDS-inspired format documentation.
         """
         # Get info, which also validates uniformity across all datasets
         format_info = self._get_format_info_inline()
@@ -269,27 +234,18 @@ class HubDatasetMixin:
 
         n_windows = format_info["total_samples"]
 
-        # Compute total duration across all recordings
-        total_duration = 0.0
         if dataset_type == "WindowsDataset":
             n_channels = len(first_ds.windows.ch_names)
             data_type = "Windowed (from Epochs object)"
             sfreq = first_ds.windows.info["sfreq"]
-            for ds in self.datasets:
-                epoch_length = ds.windows.tmax - ds.windows.tmin
-                total_duration += len(ds.windows) * epoch_length
         elif dataset_type == "EEGWindowsDataset":
             n_channels = len(first_ds.raw.ch_names)
             sfreq = first_ds.raw.info["sfreq"]
             data_type = "Windowed (from Raw object)"
-            for ds in self.datasets:
-                total_duration += ds.raw.n_times / ds.raw.info["sfreq"]
         elif dataset_type == "RawDataset":
             n_channels = len(first_ds.raw.ch_names)
             sfreq = first_ds.raw.info["sfreq"]
             data_type = "Continuous (Raw)"
-            for ds in self.datasets:
-                total_duration += ds.raw.n_times / ds.raw.info["sfreq"]
         else:
             raise TypeError(f"Unsupported dataset type: {dataset_type}")
 
@@ -301,99 +257,13 @@ class HubDatasetMixin:
             sfreq=sfreq,
             data_type=data_type,
             n_windows=n_windows,
-            total_duration=total_duration,
         )
 
         # Save README
         readme_path = path / "README.md"
-        with open(readme_path, "w"
+        with open(readme_path, "w") as f:
             f.write(readme_content)
 
-    def _save_bids_sidecar_files(
-        self, bids_layout: "hub_format.BIDSSourcedataLayout"
-    ) -> None:
-        """Save BIDS-compliant sidecar files for each recording.
-
-        This creates events.tsv, channels.tsv, and EEG sidecar JSON files
-        for each recording in a BIDS-like directory structure.
-
-        Parameters
-        ----------
-        bids_layout : BIDSSourcedataLayout
-            BIDS layout object for path generation.
-        """
-        dataset_type = get_dataset_type(self.datasets[0])
-
-        for i_ds, ds in enumerate(self.datasets):
-            # Get BIDS entities from description
-            description = ds.description if ds.description is not None else pd.Series()
-
-            # Get BIDSPath for this recording using mne_bids
-            bids_path = bids_layout.get_bids_path(description)
-
-            # Create subject directory
-            bids_path.mkdir(exist_ok=True)
-
-            # Get metadata and info based on dataset type
-            # Also compute recording_duration, recording_type, and epoch_length
-            recording_duration = None
-            recording_type = None
-            epoch_length = None
-
-            if dataset_type == "WindowsDataset":
-                info = ds.windows.info
-                metadata = ds.windows.metadata
-                sfreq = info["sfreq"]
-                # WindowsDataset contains pre-cut epochs
-                recording_type = "epoched"
-                # Use MNE's tmax - tmin for epoch length
-                epoch_length = ds.windows.tmax - ds.windows.tmin
-                # Total duration = number of epochs * epoch length
-                n_epochs = len(ds.windows)
-                recording_duration = n_epochs * epoch_length
-            elif dataset_type == "EEGWindowsDataset":
-                info = ds.raw.info
-                metadata = ds.metadata
-                sfreq = info["sfreq"]
-                # EEGWindowsDataset has continuous raw with window metadata
-                recording_type = "epoched"
-                # Use MNE Raw's duration property
-                recording_duration = ds.raw.duration
-                # Compute epoch_length from metadata if available
-                if metadata is not None and len(metadata) > 0:
-                    i_start = metadata["i_start_in_trial"].iloc[0]
-                    i_stop = metadata["i_stop_in_trial"].iloc[0]
-                    epoch_length = (i_stop - i_start) / sfreq
-            elif dataset_type == "RawDataset":
-                info = ds.raw.info
-                metadata = None
-                sfreq = info["sfreq"]
-                # RawDataset is continuous
-                recording_type = "continuous"
-                # Use MNE Raw's duration property
-                recording_duration = ds.raw.duration
-            else:
-                continue
-
-            # Determine task name from description or BIDSPath
-            task_name = bids_path.task or "unknown"
-
-            # Save BIDS sidecar files using mne_bids BIDSPath
-            hub_format.save_bids_sidecar_files(
-                bids_path=bids_path,
-                info=info,
-                metadata=metadata,
-                sfreq=sfreq,
-                task_name=str(task_name),
-                recording_duration=recording_duration,
-                recording_type=recording_type,
-                epoch_length=epoch_length,
-            )
-
-            log.debug(
-                f"Saved BIDS sidecar files for recording {i_ds} to {bids_path.directory}"
-            )
-
     @classmethod
     def pull_from_hub(
         cls,
@@ -479,19 +349,8 @@ class HubDatasetMixin:
             else:
                 format_info = {}
 
-
-
-            # Find zarr dataset path (try sourcedata, derivatives, then root)
-            zarr_path = (
-                Path(dataset_dir) / "sourcedata" / pipeline_name / "dataset.zarr"
-            )
-            if not zarr_path.exists():
-                zarr_path = (
-                    Path(dataset_dir) / "derivatives" / pipeline_name / "dataset.zarr"
-                )
-            if not zarr_path.exists():
-                zarr_path = Path(dataset_dir) / "dataset.zarr"
-
+            # Load zarr dataset
+            zarr_path = Path(dataset_dir) / "dataset.zarr"
             if not zarr_path.exists():
                 raise FileNotFoundError(
                     f"Zarr dataset not found at {zarr_path}. "
@@ -500,9 +359,6 @@ class HubDatasetMixin:
 
             dataset = cls._load_from_zarr_inline(zarr_path, preload)
 
-            # Load BIDS metadata if available
-            cls._load_bids_metadata(dataset, Path(dataset_dir), pipeline_name)
-
             log.info(f"Dataset loaded successfully from {repo_id}")
             log.info(f"Recordings: {len(dataset.datasets)}")
             log.info(
@@ -522,74 +378,6 @@ class HubDatasetMixin:
         except Exception as e:
             raise RuntimeError(f"Failed to load dataset from Hub: {e}")
 
-    @classmethod
-    def _load_bids_metadata(
-        cls,
-        dataset,
-        dataset_dir: Path,
-        pipeline_name: str,
-    ) -> None:
-        """Load BIDS metadata from sidecar files and attach to dataset.
-
-        Parameters
-        ----------
-        dataset : BaseConcatDataset
-            The loaded dataset to attach metadata to.
-        dataset_dir : Path
-            Root directory of the downloaded dataset.
-        pipeline_name : str
-            Name of the processing pipeline.
-        """
-        # Try sourcedata first, fall back to derivatives for backwards compatibility
-        sourcedata_dir = dataset_dir / "sourcedata" / pipeline_name
-        if not sourcedata_dir.exists():
-            sourcedata_dir = dataset_dir / "derivatives" / pipeline_name
-
-        # Load participants.tsv if available
-        participants_path = sourcedata_dir / "participants.tsv"
-        if participants_path.exists():
-            try:
-                participants_df = pd.read_csv(participants_path, sep="\t")
-                # Store as attribute on the concat dataset
-                dataset.participants = participants_df
-                log.debug(
-                    f"Loaded participants info for {len(participants_df)} subjects"
-                )
-            except Exception as e:
-                log.warning(f"Failed to load participants.tsv: {e}")
-
-        # Create layout for path generation
-        bids_layout = hub_format.BIDSSourcedataLayout(
-            dataset_dir, pipeline_name=pipeline_name
-        )
-
-        # Try to load events.tsv files and attach to individual datasets
-        for i_ds, ds in enumerate(dataset.datasets):
-            description = ds.description if ds.description is not None else pd.Series()
-
-            # Get BIDSPath for this recording
-            bids_path = bids_layout.get_bids_path(description)
-
-            # Load events.tsv if available
-            events_path = bids_path.copy().update(suffix="events", extension=".tsv")
-            if events_path.fpath.exists():
-                try:
-                    events_df = pd.read_csv(events_path.fpath, sep="\t")
-                    ds.bids_events = events_df
-                    log.debug(f"Loaded events for recording {i_ds}")
-                except Exception as e:
-                    log.warning(f"Failed to load events for recording {i_ds}: {e}")
-
-            # Load channels.tsv if available
-            channels_path = bids_path.copy().update(suffix="channels", extension=".tsv")
-            if channels_path.fpath.exists():
-                try:
-                    channels_df = pd.read_csv(channels_path.fpath, sep="\t")
-                    ds.bids_channels = channels_df
-                    log.debug(f"Loaded channels for recording {i_ds}")
-                except Exception as e:
-                    log.warning(f"Failed to load channels for recording {i_ds}: {e}")
-
     def _convert_to_zarr_inline(
         self,
         output_path: Path,
@@ -609,8 +397,9 @@ class HubDatasetMixin:
                 f"{output_path} already exists. Set overwrite=True to replace it."
             )
 
-        # Create zarr store (zarr
-
+        # Create zarr store (zarr v2 API)
+        store = zarr.DirectoryStore(str(output_path))
+        root = zarr.group(store=store, overwrite=False)
 
         # Validate uniformity across all datasets using shared validation
        dataset_type, _, _ = hub_validation.validate_dataset_uniformity(self.datasets)
@@ -655,7 +444,7 @@ class HubDatasetMixin:
                 data = ds.windows.get_data()
                 metadata = ds.windows.metadata
                 description = ds.description
-                info_dict = ds.windows.info
+                info_dict = _mne_info_to_dict(ds.windows.info)
                 target_name = ds.target_name if hasattr(ds, "target_name") else None
 
                 # Save using inlined function
@@ -668,7 +457,7 @@ class HubDatasetMixin:
                 raw = ds.raw
                 metadata = ds.metadata
                 description = ds.description
-                info_dict = ds.raw.info
+                info_dict = _mne_info_to_dict(ds.raw.info)
                 targets_from = ds.targets_from
                 last_target_only = ds.last_target_only
 
@@ -688,7 +477,7 @@ class HubDatasetMixin:
                 # Get continuous raw data from RawDataset
                 raw = ds.raw
                 description = ds.description
-                info_dict = ds.raw.info
+                info_dict = _mne_info_to_dict(ds.raw.info)
                 target_name = ds.target_name if hasattr(ds, "target_name") else None
 
                 # Save using inlined function
@@ -741,8 +530,9 @@ class HubDatasetMixin:
         if not input_path.exists():
             raise FileNotFoundError(f"{input_path} does not exist.")
 
-        # Open zarr store (zarr
-
+        # Open zarr store (zarr v2 API)
+        store = zarr.DirectoryStore(str(input_path))
+        root = zarr.group(store=store)
 
         n_datasets = root.attrs["n_datasets"]
         dataset_type = root.attrs["dataset_type"]
@@ -764,7 +554,7 @@ class HubDatasetMixin:
                 )
 
                 # Convert to MNE objects and create dataset
-                info =
+                info = _dict_to_mne_info(info_dict)
                 events = np.column_stack(
                     [
                         metadata["i_start_in_trial"].values,
@@ -790,7 +580,7 @@ class HubDatasetMixin:
 
                 # Convert to MNE objects and create dataset
                 # Data is already in continuous format [n_channels, n_timepoints]
-                info =
+                info = _dict_to_mne_info(info_dict)
                 raw = mne.io.RawArray(data, info)
                 ds = EEGWindowsDataset(
                     raw=raw,
@@ -808,7 +598,7 @@ class HubDatasetMixin:
 
                 # Convert to MNE objects and create dataset
                 # Data is in continuous format [n_channels, n_timepoints]
-                info =
+                info = _dict_to_mne_info(info_dict)
                 raw = mne.io.RawArray(data, info)
                 ds = RawDataset(raw, description)
                 if target_name is not None:
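The hunks above replace the truncated store-creation lines with an explicit zarr v2 store and swap raw MNE Info objects for JSON-serializable dicts before writing group attributes. A minimal, self-contained sketch of the zarr v2 + numcodecs pattern being used (assuming zarr < 3 and numcodecs are installed; the path, shapes, and attribute names are illustrative, not braindecode's):

# Sketch of the zarr v2 storage pattern used by _convert_to_zarr_inline above.
import numpy as np
import zarr
from numcodecs import Blosc

store = zarr.DirectoryStore("example.zarr")          # directory-backed zarr v2 store
root = zarr.group(store=store, overwrite=True)

data = np.random.randn(64, 10_000).astype(np.float32)  # (n_channels, n_times), illustrative
root.create_dataset(
    "data",
    data=data,
    chunks=(data.shape[0], min(10_000, data.shape[1])),  # chunking as in the diff
    compressor=Blosc(cname="zstd", clevel=5),            # default produced by _create_compressor
)
root.attrs["sfreq"] = 250.0                           # JSON-serializable metadata lives in attrs

# Reading back: slicing a zarr array returns a NumPy array
loaded = zarr.open_group("example.zarr", mode="r")
restored = loaded["data"][:]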
@@ -837,6 +627,251 @@ class HubDatasetMixin:
         return concat_ds
 
 
+# =============================================================================
+# Core Zarr I/O Utilities
+# =============================================================================
+
+
+# TODO: remove when this MNE is solved https://github.com/mne-tools/mne-python/issues/13487
+def _mne_info_to_dict(info):
+    """Convert MNE Info object to dictionary for JSON serialization."""
+    return {
+        "ch_names": info["ch_names"],
+        "sfreq": float(info["sfreq"]),
+        "ch_types": [str(ch_type) for ch_type in info.get_channel_types()],
+        "lowpass": float(info["lowpass"]) if info["lowpass"] is not None else None,
+        "highpass": float(info["highpass"]) if info["highpass"] is not None else None,
+    }
+
+
+def _dict_to_mne_info(info_dict):
+    """Convert dictionary back to MNE Info object."""
+    info = mne.create_info(
+        ch_names=info_dict["ch_names"],
+        sfreq=info_dict["sfreq"],
+        ch_types=info_dict["ch_types"],
+    )
+
+    # Use _unlock() to set filter info when reconstructing from saved metadata
+    # This is necessary because MNE protects these fields to prevent users from
+    # setting filter parameters without actually filtering the data
+    with info._unlock():
+        if info_dict.get("lowpass") is not None:
+            info["lowpass"] = info_dict["lowpass"]
+        if info_dict.get("highpass") is not None:
+            info["highpass"] = info_dict["highpass"]
+
+    return info
+
+
+def _save_windows_to_zarr(
+    grp, data, metadata, description, info, compressor, target_name
+):
+    """Save windowed data to Zarr group (low-level function)."""
+    # Save data with chunking for random access
+    grp.create_dataset(
+        "data",
+        data=data.astype(np.float32),
+        chunks=(1, data.shape[1], data.shape[2]),
+        compressor=compressor,
+    )
+
+    # Save metadata
+    metadata_json = metadata.to_json(orient="split", date_format="iso")
+    grp.attrs["metadata"] = metadata_json
+    # Save dtypes to preserve them across platforms (int32 vs int64, etc.)
+    metadata_dtypes = metadata.dtypes.apply(str).to_json()
+    grp.attrs["metadata_dtypes"] = metadata_dtypes
+
+    # Save description
+    description_json = description.to_json(date_format="iso")
+    grp.attrs["description"] = description_json
+
+    # Save MNE info
+    grp.attrs["info"] = json.dumps(info)
+
+    # Save target name if provided
+    if target_name is not None:
+        grp.attrs["target_name"] = target_name
+
+
+def _save_eegwindows_to_zarr(
+    grp, raw, metadata, description, info, targets_from, last_target_only, compressor
+):
+    """Save EEG continuous raw data to Zarr group (low-level function)."""
+    # Extract continuous data from Raw [n_channels, n_timepoints]
+    continuous_data = raw.get_data()
+
+    # Save continuous data with chunking optimized for window extraction
+    # Chunk size: all channels, 10000 timepoints for efficient random access
+    grp.create_dataset(
+        "data",
+        data=continuous_data.astype(np.float32),
+        chunks=(continuous_data.shape[0], min(10000, continuous_data.shape[1])),
+        compressor=compressor,
+    )
+
+    # Save metadata
+    metadata_json = metadata.to_json(orient="split", date_format="iso")
+    grp.attrs["metadata"] = metadata_json
+    # Save dtypes to preserve them across platforms (int32 vs int64, etc.)
+    metadata_dtypes = metadata.dtypes.apply(str).to_json()
+    grp.attrs["metadata_dtypes"] = metadata_dtypes
+
+    # Save description
+    description_json = description.to_json(date_format="iso")
+    grp.attrs["description"] = description_json
+
+    # Save MNE info
+    grp.attrs["info"] = json.dumps(info)
+
+    # Save EEGWindowsDataset-specific attributes
+    grp.attrs["targets_from"] = targets_from
+    grp.attrs["last_target_only"] = last_target_only
+
+
+def _load_windows_from_zarr(grp, preload):
+    """Load windowed data from Zarr group (low-level function)."""
+    # Load metadata
+    metadata = pd.read_json(io.StringIO(grp.attrs["metadata"]), orient="split")
+    # Restore dtypes to preserve them across platforms (int32 vs int64, etc.)
+    dtypes_dict = pd.read_json(io.StringIO(grp.attrs["metadata_dtypes"]), typ="series")
+    for col, dtype_str in dtypes_dict.items():
+        metadata[col] = metadata[col].astype(dtype_str)
+
+    # Load description
+    description = pd.read_json(io.StringIO(grp.attrs["description"]), typ="series")
+
+    # Load info
+    info_dict = json.loads(grp.attrs["info"])
+
+    # Load data
+    if preload:
+        data = grp["data"][:]
+    else:
+        data = grp["data"][:]
+        # TODO: Implement lazy loading properly
+        warnings.warn(
+            "Lazy loading from Zarr not fully implemented yet. "
+            "Loading all data into memory.",
+            UserWarning,
+        )
+
+    # Load target name
+    target_name = grp.attrs.get("target_name", None)
+
+    return data, metadata, description, info_dict, target_name
+
+
+def _load_eegwindows_from_zarr(grp, preload):
+    """Load EEG continuous raw data from Zarr group (low-level function)."""
+    # Load metadata
+    metadata = pd.read_json(io.StringIO(grp.attrs["metadata"]), orient="split")
+    # Restore dtypes to preserve them across platforms (int32 vs int64, etc.)
+    dtypes_dict = pd.read_json(io.StringIO(grp.attrs["metadata_dtypes"]), typ="series")
+    for col, dtype_str in dtypes_dict.items():
+        metadata[col] = metadata[col].astype(dtype_str)
+
+    # Load description
+    description = pd.read_json(io.StringIO(grp.attrs["description"]), typ="series")
+
+    # Load info
+    info_dict = json.loads(grp.attrs["info"])
+
+    # Load data
+    if preload:
+        data = grp["data"][:]
+    else:
+        data = grp["data"][:]
+        warnings.warn(
+            "Lazy loading from Zarr not fully implemented yet. "
+            "Loading all data into memory.",
+            UserWarning,
+        )
+
+    # Load EEGWindowsDataset-specific attributes
+    targets_from = grp.attrs.get("targets_from", "metadata")
+    last_target_only = grp.attrs.get("last_target_only", True)
+
+    return data, metadata, description, info_dict, targets_from, last_target_only
+
+
+def _save_raw_to_zarr(grp, raw, description, info, target_name, compressor):
+    """Save RawDataset continuous raw data to Zarr group (low-level function)."""
+    # Extract continuous data from Raw [n_channels, n_timepoints]
+    continuous_data = raw.get_data()
+
+    # Save continuous data with chunking optimized for efficient access
+    # Chunk size: all channels, 10000 timepoints for efficient random access
+    grp.create_dataset(
+        "data",
+        data=continuous_data.astype(np.float32),
+        chunks=(continuous_data.shape[0], min(10000, continuous_data.shape[1])),
+        compressor=compressor,
+    )
+
+    # Save description
+    description_json = description.to_json(date_format="iso")
+    grp.attrs["description"] = description_json
+
+    # Save MNE info
+    grp.attrs["info"] = json.dumps(info)
+
+    # Save target name if provided
+    if target_name is not None:
+        grp.attrs["target_name"] = target_name
+
+
+def _load_raw_from_zarr(grp, preload):
+    """Load RawDataset continuous raw data from Zarr group (low-level function)."""
+    # Load description
+    description = pd.read_json(io.StringIO(grp.attrs["description"]), typ="series")
+
+    # Load info
+    info_dict = json.loads(grp.attrs["info"])
+
+    # Load data
+    if preload:
+        data = grp["data"][:]
+    else:
+        data = grp["data"][:]
+        # TODO: Implement lazy loading properly
+        warnings.warn(
+            "Lazy loading from Zarr not fully implemented yet. "
+            "Loading all data into memory.",
+            UserWarning,
+        )
+
+    # Load target name
+    target_name = grp.attrs.get("target_name", None)
+
+    return data, description, info_dict, target_name
+
+
+def _create_compressor(compression, compression_level):
+    """Create a Zarr compressor object (zarr v2 API)."""
+    if zarr is False:
+        raise ImportError(
+            "Zarr is not installed. Install with: pip install braindecode[hub]"
+        )
+
+    if not NUMCODECS_AVAILABLE:
+        raise ImportError(
+            "numcodecs is not installed. Install with: pip install braindecode[hub]"
+        )
+
+    # Zarr v2 uses numcodecs compressors
+    if compression == "blosc":
+        return Blosc(cname="zstd", clevel=compression_level)
+    elif compression == "zstd":
+        return Zstd(level=compression_level)
+    elif compression == "gzip":
+        return GZip(level=compression_level)
+    else:
+        return None
+
+
+# TODO: improve content
 def _generate_readme_content(
     format_info,
     n_recordings: int,
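The new module-level helpers above serialize only a small subset of the MNE Info fields and rebuild the object with mne.create_info, using Info._unlock() to restore the protected filter limits. A standalone sketch of that round trip (same idea outside braindecode; note that _unlock() is a private MNE API, used here only because the diff itself relies on it):

# Sketch of the Info -> dict -> JSON -> Info round trip implemented above.
import json
import mne

info = mne.create_info(ch_names=["C3", "C4", "Cz"], sfreq=250.0, ch_types="eeg")

# Serialize the subset of fields the diff keeps
info_dict = {
    "ch_names": info["ch_names"],
    "sfreq": float(info["sfreq"]),
    "ch_types": [str(t) for t in info.get_channel_types()],
    "lowpass": float(info["lowpass"]),
    "highpass": float(info["highpass"]),
}
payload = json.dumps(info_dict)              # what ends up in grp.attrs["info"]

# Rebuild the Info from the JSON payload
restored = json.loads(payload)
new_info = mne.create_info(restored["ch_names"], restored["sfreq"], restored["ch_types"])
with new_info._unlock():                     # filter limits are protected fields in MNE
    new_info["lowpass"] = restored["lowpass"]
    new_info["highpass"] = restored["highpass"]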
@@ -844,144 +879,84 @@ def _generate_readme_content(
     sfreq,
     data_type: str,
     n_windows: int,
-    total_duration: float | None = None,
     format: str = "zarr",
 ):
-    """Generate README.md content for a dataset uploaded to the Hub.
-
-    Parameters
-    ----------
-    format_info : dict
-        Dictionary containing format metadata (e.g., total_size_mb).
-    n_recordings : int
-        Number of recordings in the dataset.
-    n_channels : int
-        Number of EEG channels.
-    sfreq : float or None
-        Sampling frequency in Hz.
-    data_type : str
-        Type of dataset (e.g., "Windowed", "Continuous").
-    n_windows : int
-        Number of windows/samples in the dataset.
-    total_duration : float or None
-        Total duration in seconds across all recordings.
-    format : str
-        Storage format (default: "zarr").
-
-    Returns
-    -------
-    str
-        Markdown content for the README.md file.
-    """
+    """Generate README.md content for a dataset uploaded to the Hub."""
+    # Use safe access for total size and format sfreq nicely
     total_size_mb = (
         format_info.get("total_size_mb", 0.0) if isinstance(format_info, dict) else 0.0
     )
     sfreq_str = f"{sfreq:g}" if sfreq is not None else "N/A"
 
-    duration_str = (
-        str(timedelta(seconds=int(total_duration))) if total_duration else "N/A"
-    )
-
     return f"""---
 tags:
 - braindecode
 - eeg
 - neuroscience
 - brain-computer-interface
-- deep-learning
 license: unknown
 ---
 
 # EEG Dataset
 
-This dataset was created using [braindecode](https://braindecode.org), a deep
-learning library for EEG/MEG/ECoG signals.
+This dataset was created using [braindecode](https://braindecode.org), a library for deep learning with EEG/MEG/ECoG signals.
 
 ## Dataset Information
 
 | Property | Value |
-|
-|
-|
-|
+|---|---:|
+| Number of recordings | {n_recordings} |
+| Dataset type | {data_type} |
+| Number of channels | {n_channels} |
 | Sampling frequency | {sfreq_str} Hz |
-|
-|
-|
-| Format | {format} |
-
-## Quick Start
-
-```python
-from braindecode.datasets import BaseConcatDataset
-
-# Load from Hugging Face Hub
-dataset = BaseConcatDataset.pull_from_hub("username/dataset-name")
-
-# Access a sample
-X, y, metainfo = dataset[0]
-# X: EEG data [n_channels, n_times]
-# y: target label
-# metainfo: window indices
-```
-
-## Training with PyTorch
-
-```python
-from torch.utils.data import DataLoader
-
-loader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4)
-
-for X, y, metainfo in loader:
-    # X: [batch_size, n_channels, n_times]
-    # y: [batch_size]
-    pass # Your training code
-```
-
-## BIDS-inspired Structure
-
-This dataset uses a **BIDS-inspired** organization. Metadata files follow BIDS
-conventions, while data is stored in Zarr format for efficient deep learning.
-
-**BIDS-style metadata:**
-- `dataset_description.json` - Dataset information
-- `participants.tsv` - Subject metadata
-- `*_events.tsv` - Trial/window events
-- `*_channels.tsv` - Channel information
-- `*_eeg.json` - Recording parameters
-
-**Data storage:**
-- `dataset.zarr/` - Zarr format (optimized for random access)
-
-```
-sourcedata/braindecode/
-├── dataset_description.json
-├── participants.tsv
-├── dataset.zarr/
-└── sub-<label>/
-    └── eeg/
-        ├── *_events.tsv
-        ├── *_channels.tsv
-        └── *_eeg.json
-```
-
-### Accessing Metadata
-
-```python
-# Participants info
-if hasattr(dataset, "participants"):
-    print(dataset.participants)
-
-# Events for a recording
-if hasattr(dataset.datasets[0], "bids_events"):
-    print(dataset.datasets[0].bids_events)
-
-# Channel info
-if hasattr(dataset.datasets[0], "bids_channels"):
-    print(dataset.datasets[0].bids_channels)
-```
+| Number of windows / samples | {n_windows} |
+| Total size | {total_size_mb:.2f} MB |
+| Storage format | {format} |
 
-
+## Usage
+
+To load this dataset::
+
+.. code-block:: python
+
+    from braindecode.datasets import BaseConcatDataset
+
+    # Load dataset from Hugging Face Hub
+    dataset = BaseConcatDataset.pull_from_hub("username/dataset-name")
+
+    # Access data
+    X, y, metainfo = dataset[0]
+    # X: EEG data (n_channels, n_times)
+    # y: label/target
+    # metainfo: window indices
+
+## Using with PyTorch DataLoader
+
+::
+
+    from torch.utils.data import DataLoader
+
+    # Create DataLoader for training
+    train_loader = DataLoader(
+        dataset,
+        batch_size=32,
+        shuffle=True,
+        num_workers=4
+    )
+
+    # Training loop
+    for X, y, metainfo in train_loader:
+        # X shape: [batch_size, n_channels, n_times]
+        # y shape: [batch_size]
+        # metainfo shape: [batch_size, 2] (start and end indices)
+        # Process your batch...
+
+## Dataset Format
+
+This dataset is stored in **Zarr** format, optimized for:
+- Fast random access during training (critical for PyTorch DataLoader)
+- Efficient compression with blosc
+- Cloud-native storage compatibility
 
-
+For more information about braindecode, visit: https://braindecode.org
 """