braindecode 1.3.0.dev177628147__py3-none-any.whl → 1.3.0.dev182330353__py3-none-any.whl

This diff compares the contents of two publicly released versions of this package as published to a supported public registry; it is provided for informational purposes only.
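The change removes braindecode/datasets/hub.py, the module that gave BaseConcatDataset its Hugging Face Hub integration (a HubDatasetMixin providing push_to_hub() and pull_from_hub()). For orientation, here is a minimal usage sketch reconstructed from the docstrings of the removed file; it applies to the older wheel only, the repository ID and local path are placeholders, the optional hub extras are assumed to be installed, and NMT is simply the dataset used in the module's own examples.

```python
# Usage sketch based on the docstrings in the removed hub.py (older wheel,
# 1.3.0.dev177628147, with `pip install braindecode[hub]`). The repo ID and
# local path below are placeholders, not real resources.
from braindecode.datasets import NMT, BaseConcatDataset

# Push: the dataset is converted to Zarr, a README.md and format_info.json are
# generated, and the folder is uploaded to a Hub *dataset* repository.
dataset = NMT(path="path/to/nmt", preload=True)
url = dataset.push_to_hub(
    repo_id="username/nmt-dataset",
    commit_message="Upload NMT EEG dataset",
    compression="blosc",
    compression_level=5,
)

# Pull: the repository snapshot is downloaded and dataset.zarr is rebuilt
# into a BaseConcatDataset, ready for a PyTorch DataLoader.
dataset = BaseConcatDataset.pull_from_hub("username/nmt-dataset")
print(f"Loaded {len(dataset)} windows")
```

The full removed file follows.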
@@ -1,962 +0,0 @@
- """
- Hugging Face Hub integration for EEG datasets.
-
- This module provides push_to_hub() and pull_from_hub() functionality
- for braindecode datasets, similar to the model Hub integration.
- """
-
- # Authors: Kuntal Kokate
- #
- # License: BSD (3-clause)
-
- import io
- import json
- import logging
- import tempfile
- import warnings
- from pathlib import Path
- from typing import TYPE_CHECKING, List, Optional, Union
-
- import mne
- import numpy as np
- import pandas as pd
- import scipy
- from mne.utils import _soft_import
-
- # TODO: Simplify this logic in the future with zarr v3+ only
- # Optional imports for Hub functionality
- try:
-     from numcodecs import Blosc, GZip, Zstd
-
-     NUMCODECS_AVAILABLE = True
- except ImportError:
-     NUMCODECS_AVAILABLE = False
-     Blosc = GZip = Zstd = None
-
- if TYPE_CHECKING:
-     from .base import BaseDataset
-
- import braindecode
-
- # Import shared validation utilities
- from . import hub_validation
-
- # Import registry for dynamic class lookup (avoids circular imports)
- from .registry import get_dataset_class, get_dataset_type
-
- # Lazy import zarr and huggingface_hub
- zarr = _soft_import("zarr", purpose="hugging face integration", strict=False)
- huggingface_hub = _soft_import(
-     "huggingface_hub", purpose="hugging face integration", strict=False
- )
-
- log = logging.getLogger(__name__)
-
-
- class HubDatasetMixin:
-     """
-     Mixin class for Hugging Face Hub integration with EEG datasets.
-
-     This class adds `push_to_hub()` and `pull_from_hub()` methods to
-     BaseConcatDataset, enabling easy upload and download of datasets
-     to/from the Hugging Face Hub.
-
-     Examples
-     --------
-     >>> # Push dataset to Hub
-     >>> dataset = NMT(path=path, preload=True)
-     >>> dataset.push_to_hub(
-     ...     repo_id="username/nmt-dataset",
-     ...     commit_message="Add NMT dataset"
-     ... )
-     >>>
-     >>> # Load dataset from Hub
-     >>> dataset = BaseConcatDataset.pull_from_hub("username/nmt-dataset")
-     """
-
-     datasets: List["BaseDataset"]  # Attribute provided by inheriting class
-
-     def push_to_hub(
-         self,
-         repo_id: str,
-         commit_message: Optional[str] = None,
-         private: bool = False,
-         token: Optional[str] = None,
-         create_pr: bool = False,
-         compression: str = "blosc",
-         compression_level: int = 5,
-     ) -> str:
-         """
-         Upload the dataset to the Hugging Face Hub in Zarr format.
-
-         The dataset is converted to Zarr format with blosc compression, which provides
-         optimal random access performance for PyTorch training (based on comprehensive
-         benchmarking).
-
-         Parameters
-         ----------
-         repo_id : str
-             Repository ID on the Hugging Face Hub (e.g., "username/dataset-name").
-         commit_message : str | None
-             Commit message. If None, a default message is generated.
-         private : bool, default=False
-             Whether to create a private repository.
-         token : str | None
-             Hugging Face API token. If None, uses cached token.
-         create_pr : bool, default=False
-             Whether to create a Pull Request instead of directly committing.
-         compression : str, default="blosc"
-             Compression algorithm for Zarr. Options: "blosc", "zstd", "gzip", None.
-         compression_level : int, default=5
-             Compression level (0-9). Level 5 provides optimal balance.
-
-         Returns
-         -------
-         str
-             URL of the uploaded dataset on the Hub.
-
-         Raises
-         ------
-         ImportError
-             If huggingface-hub is not installed.
-         ValueError
-             If the dataset is empty or format is invalid.
-
-         Examples
-         --------
-         >>> dataset = NMT(path=path, preload=True)
-         >>> # Upload with default settings (zarr with blosc compression)
-         >>> url = dataset.push_to_hub(
-         ...     repo_id="myusername/nmt-dataset",
-         ...     commit_message="Upload NMT EEG dataset"
-         ... )
-         >>>
-         >>> # Or customize compression
-         >>> url = dataset.push_to_hub(
-         ...     repo_id="myusername/nmt-dataset",
-         ...     compression="blosc",
-         ...     compression_level=5
-         ... )
-         """
-         if huggingface_hub is False or zarr is False:
-             raise ImportError(
-                 "huggingface-hub or zarr is not installed. Install with: "
-                 "pip install braindecode[hub]"
-             )
-
-         # Create API instance
-         _ = huggingface_hub.HfApi(token=token)
-
-         # Create repository if it doesn't exist
-         try:
-             huggingface_hub.create_repo(
-                 repo_id=repo_id,
-                 token=token,
-                 private=private,
-                 repo_type="dataset",
-                 exist_ok=True,
-             )
-         except Exception as e:
-             raise RuntimeError(f"Failed to create repository: {e}")
-
-         # Create a temporary directory for upload
-         with tempfile.TemporaryDirectory() as tmpdir:
-             tmp_path = Path(tmpdir)
-
-             # Convert dataset to Zarr format
-             log.info("Converting dataset to Zarr format...")
-             dataset_path = tmp_path / "dataset.zarr"
-             self._convert_to_zarr_inline(
-                 dataset_path,
-                 compression,
-                 compression_level,
-             )
-
-             # Save dataset metadata
-             self._save_dataset_card(tmp_path)
-
-             # Save format info
-             format_info_path = tmp_path / "format_info.json"
-             with open(format_info_path, "w") as f:
-                 format_info = self._get_format_info_inline()
-                 json.dump(
-                     {
-                         "format": "zarr",
-                         "compression": compression,
-                         "compression_level": compression_level,
-                         "braindecode_version": braindecode.__version__,
-                         **format_info,
-                     },
-                     f,
-                     indent=2,
-                 )
-
-             # Default commit message
-             if commit_message is None:
-                 commit_message = (
-                     f"Upload EEG dataset in Zarr format "
-                     f"({len(self.datasets)} recordings)"
-                 )
-
-             # Upload folder to Hub
-             log.info(f"Uploading to Hugging Face Hub ({repo_id})...")
-             try:
-                 url = huggingface_hub.upload_folder(
-                     repo_id=repo_id,
-                     folder_path=str(tmp_path),
-                     repo_type="dataset",
-                     commit_message=commit_message,
-                     token=token,
-                     create_pr=create_pr,
-                 )
-                 log.info(f"Dataset uploaded successfully to {repo_id}")
-                 log.info(f"URL: https://huggingface.co/datasets/{repo_id}")
-                 return url
-             except Exception as e:
-                 raise RuntimeError(f"Failed to upload dataset: {e}")
-
-     def _save_dataset_card(self, path: Path) -> None:
-         """Generate and save a dataset card (README.md) with metadata.
-
-         Parameters
-         ----------
-         path : Path
-             Directory where README.md will be saved.
-         """
-         # Get info, which also validates uniformity across all datasets
-         format_info = self._get_format_info_inline()
-
-         n_recordings = len(self.datasets)
-         first_ds = self.datasets[0]
-
-         # Get dataset-specific info based on type using registry
-         dataset_type = get_dataset_type(first_ds)
-
-         n_windows = format_info["total_samples"]
-
-         if dataset_type == "WindowsDataset":
-             n_channels = len(first_ds.windows.ch_names)
-             data_type = "Windowed (from Epochs object)"
-             sfreq = first_ds.windows.info["sfreq"]
-         elif dataset_type == "EEGWindowsDataset":
-             n_channels = len(first_ds.raw.ch_names)
-             sfreq = first_ds.raw.info["sfreq"]
-             data_type = "Windowed (from Raw object)"
-         elif dataset_type == "RawDataset":
-             n_channels = len(first_ds.raw.ch_names)
-             sfreq = first_ds.raw.info["sfreq"]
-             data_type = "Continuous (Raw)"
-         else:
-             raise TypeError(f"Unsupported dataset type: {dataset_type}")
-
-         # Create README content and save
-         readme_content = _generate_readme_content(
-             format_info=format_info,
-             n_recordings=n_recordings,
-             n_channels=n_channels,
-             sfreq=sfreq,
-             data_type=data_type,
-             n_windows=n_windows,
-         )
-
-         # Save README
-         readme_path = path / "README.md"
-         with open(readme_path, "w") as f:
-             f.write(readme_content)
-
-     @classmethod
-     def pull_from_hub(
-         cls,
-         repo_id: str,
-         preload: bool = True,
-         token: Optional[str] = None,
-         cache_dir: Optional[Union[str, Path]] = None,
-         force_download: bool = False,
-         **kwargs,
-     ):
-         """
-         Load a dataset from the Hugging Face Hub.
-
-         Parameters
-         ----------
-         repo_id : str
-             Repository ID on the Hugging Face Hub (e.g., "username/dataset-name").
-         preload : bool, default=True
-             Whether to preload the data into memory. If False, uses lazy loading
-             (when supported by the format).
-         token : str | None
-             Hugging Face API token. If None, uses cached token.
-         cache_dir : str | Path | None
-             Directory to cache the downloaded dataset. If None, uses default
-             cache directory (~/.cache/huggingface/datasets).
-         force_download : bool, default=False
-             Whether to force re-download even if cached.
-         **kwargs
-             Additional arguments (currently unused).
-
-         Returns
-         -------
-         BaseConcatDataset
-             The loaded dataset.
-
-         Raises
-         ------
-         ImportError
-             If huggingface-hub is not installed.
-         FileNotFoundError
-             If the repository or dataset files are not found.
-
-         Examples
-         --------
-         >>> from braindecode.datasets import BaseConcatDataset
-         >>> dataset = BaseConcatDataset.pull_from_hub("username/nmt-dataset")
-         >>> print(f"Loaded {len(dataset)} windows")
-         >>>
-         >>> # Use with PyTorch
-         >>> from torch.utils.data import DataLoader
-         >>> loader = DataLoader(dataset, batch_size=32, shuffle=True)
-         """
-         if zarr is False or huggingface_hub is False:
-             raise ImportError(
-                 "huggingface hub functionality is not installed. Install with: "
-                 "pip install braindecode[hub]"
-             )
-
-         log.info(f"Loading dataset from Hugging Face Hub ({repo_id})...")
-
-         try:
-             # Download the entire dataset directory
-             dataset_dir = huggingface_hub.snapshot_download(
-                 repo_id=repo_id,
-                 repo_type="dataset",
-                 token=token,
-                 cache_dir=cache_dir,
-                 force_download=force_download,
-             )
-
-             # Load format info
-             format_info_path = Path(dataset_dir) / "format_info.json"
-             if format_info_path.exists():
-                 with open(format_info_path, "r") as f:
-                     format_info = json.load(f)
-
-                 # Verify it's zarr format
-                 if format_info.get("format") != "zarr":
-                     raise ValueError(
-                         f"Dataset format is '{format_info.get('format')}', but only "
-                         "'zarr' format is supported. Please re-upload the dataset."
-                     )
-             else:
-                 format_info = {}
-
-             # Load zarr dataset
-             zarr_path = Path(dataset_dir) / "dataset.zarr"
-             if not zarr_path.exists():
-                 raise FileNotFoundError(
-                     f"Zarr dataset not found at {zarr_path}. "
-                     "The dataset may be in an unsupported format."
-                 )
-
-             dataset = cls._load_from_zarr_inline(zarr_path, preload)
-
-             log.info(f"Dataset loaded successfully from {repo_id}")
-             log.info(f"Recordings: {len(dataset.datasets)}")
-             log.info(
-                 f"Total windows/samples: {format_info.get('total_samples', 'N/A')}"
-             )
-
-             return dataset
-
-         except huggingface_hub.utils.HfHubHTTPError as e:
-             if e.response.status_code == 404:
-                 raise FileNotFoundError(
-                     f"Dataset '{repo_id}' not found on Hugging Face Hub. "
-                     "Please check the repository ID and ensure it exists."
-                 )
-             else:
-                 raise RuntimeError(f"Failed to download dataset: {e}")
-         except Exception as e:
-             raise RuntimeError(f"Failed to load dataset from Hub: {e}")
-
-     def _convert_to_zarr_inline(
-         self,
-         output_path: Path,
-         compression: str,
-         compression_level: int,
-     ) -> None:
-         """Convert dataset to Zarr format (inline implementation)."""
-
-         if zarr is False or huggingface_hub is False:
-             raise ImportError(
-                 "huggingface hub functionality is not installed. Install with: "
-                 "pip install braindecode[hub]"
-             )
-
-         if output_path.exists():
-             raise FileExistsError(
-                 f"{output_path} already exists. Set overwrite=True to replace it."
-             )
-
-         # Create zarr store (zarr v2 API)
-         store = zarr.DirectoryStore(str(output_path))
-         root = zarr.group(store=store, overwrite=False)
-
-         # Validate uniformity across all datasets using shared validation
-         dataset_type, _, _ = hub_validation.validate_dataset_uniformity(self.datasets)
-
-         # Keep reference to first dataset for preprocessing kwargs
-         first_ds = self.datasets[0]
-
-         # Store global metadata
-         root.attrs["n_datasets"] = len(self.datasets)
-         root.attrs["dataset_type"] = dataset_type
-         root.attrs["braindecode_version"] = braindecode.__version__
-
-         # Track dependency versions for reproducibility
-         root.attrs["mne_version"] = mne.__version__
-         root.attrs["numpy_version"] = np.__version__
-         root.attrs["pandas_version"] = pd.__version__
-         root.attrs["zarr_version"] = zarr.__version__
-         root.attrs["scipy_version"] = scipy.__version__
-
-         # Save preprocessing kwargs (check first dataset, assuming uniform preprocessing)
-         # These are typically set by windowing functions on individual datasets
-         for kwarg_name in [
-             "raw_preproc_kwargs",
-             "window_kwargs",
-             "window_preproc_kwargs",
-         ]:
-             # Check first dataset for these attributes
-             if hasattr(first_ds, kwarg_name):
-                 kwargs = getattr(first_ds, kwarg_name)
-                 if kwargs:
-                     root.attrs[kwarg_name] = json.dumps(kwargs)
-
-         # Create compressor
-         compressor = _create_compressor(compression, compression_level)
-
-         # Save each recording
-         for i_ds, ds in enumerate(self.datasets):
-             grp = root.create_group(f"recording_{i_ds}")
-
-             if dataset_type == "WindowsDataset":
-                 # Extract data from WindowsDataset
-                 data = ds.windows.get_data()
-                 metadata = ds.windows.metadata
-                 description = ds.description
-                 info_dict = _mne_info_to_dict(ds.windows.info)
-                 target_name = ds.target_name if hasattr(ds, "target_name") else None
-
-                 # Save using inlined function
-                 _save_windows_to_zarr(
-                     grp, data, metadata, description, info_dict, compressor, target_name
-                 )
-
-             elif dataset_type == "EEGWindowsDataset":
-                 # Get continuous raw data and metadata from EEGWindowsDataset
-                 raw = ds.raw
-                 metadata = ds.metadata
-                 description = ds.description
-                 info_dict = _mne_info_to_dict(ds.raw.info)
-                 targets_from = ds.targets_from
-                 last_target_only = ds.last_target_only
-
-                 # Save using inlined function (saves continuous raw directly)
-                 _save_eegwindows_to_zarr(
-                     grp,
-                     raw,
-                     metadata,
-                     description,
-                     info_dict,
-                     targets_from,
-                     last_target_only,
-                     compressor,
-                 )
-
-             elif dataset_type == "RawDataset":
-                 # Get continuous raw data from RawDataset
-                 raw = ds.raw
-                 description = ds.description
-                 info_dict = _mne_info_to_dict(ds.raw.info)
-                 target_name = ds.target_name if hasattr(ds, "target_name") else None
-
-                 # Save using inlined function
-                 _save_raw_to_zarr(
-                     grp, raw, description, info_dict, target_name, compressor
-                 )
-
-     def _get_format_info_inline(self):
-         """Get format info (inline implementation).
-
-         This is an inline version of hub_formats.get_format_info() that avoids
-         circular import.
-         """
-         if len(self.datasets) == 0:
-             raise ValueError("Cannot get format info for empty dataset")
-
-         # Validate uniformity across all datasets using shared validation
-         dataset_type, _, _ = hub_validation.validate_dataset_uniformity(self.datasets)
-
-         # Calculate dataset size
-         # BaseConcatDataset's __len__ already sums len(ds) for all datasets
-         total_samples = len(self)
-         total_size_mb = 0
-
-         for ds in self.datasets:
-             if dataset_type == "WindowsDataset":
-                 # Use MNE's internal _size property to avoid loading data
-                 total_size_mb += ds.windows._size / (1024 * 1024)
-             elif dataset_type == "EEGWindowsDataset":
-                 # Use raw object's size (not extracted windows)
-                 total_size_mb += ds.raw._size / (1024 * 1024)
-             elif dataset_type == "RawDataset":
-                 total_size_mb += ds.raw._size / (1024 * 1024)
-
-         n_recordings = len(self.datasets)
-
-         return {
-             "n_recordings": n_recordings,
-             "total_samples": total_samples,
-             "total_size_mb": round(total_size_mb, 2),
-         }
-
-     @staticmethod
-     def _load_from_zarr_inline(input_path: Path, preload: bool):
-         """Load dataset from Zarr format (inline implementation).
-
-         This is an inline version of hub_formats.load_from_zarr() that avoids
-         circular import by using hub_formats_core directly.
-         """
-         if not input_path.exists():
-             raise FileNotFoundError(f"{input_path} does not exist.")
-
-         # Open zarr store (zarr v2 API)
-         store = zarr.DirectoryStore(str(input_path))
-         root = zarr.group(store=store)
-
-         n_datasets = root.attrs["n_datasets"]
-         dataset_type = root.attrs["dataset_type"]
-
-         # Get dataset classes from registry
-         WindowsDataset = get_dataset_class("WindowsDataset")
-         EEGWindowsDataset = get_dataset_class("EEGWindowsDataset")
-         RawDataset = get_dataset_class("RawDataset")
-         BaseConcatDataset = get_dataset_class("BaseConcatDataset")
-
-         datasets = []
-         for i_ds in range(n_datasets):
-             grp = root[f"recording_{i_ds}"]
-
-             if dataset_type == "WindowsDataset":
-                 # Load using inlined function
-                 data, metadata, description, info_dict, target_name = (
-                     _load_windows_from_zarr(grp, preload)
-                 )
-
-                 # Convert to MNE objects and create dataset
-                 info = _dict_to_mne_info(info_dict)
-                 events = np.column_stack(
-                     [
-                         metadata["i_start_in_trial"].values,
-                         np.zeros(len(metadata), dtype=int),
-                         metadata["target"].values,
-                     ]
-                 )
-                 epochs = mne.EpochsArray(data, info, events=events, metadata=metadata)
-                 ds = WindowsDataset(epochs, description)
-                 if target_name is not None:
-                     ds.target_name = target_name
-
-             elif dataset_type == "EEGWindowsDataset":
-                 # Load using inlined function
-                 (
-                     data,
-                     metadata,
-                     description,
-                     info_dict,
-                     targets_from,
-                     last_target_only,
-                 ) = _load_eegwindows_from_zarr(grp, preload)
-
-                 # Convert to MNE objects and create dataset
-                 # Data is already in continuous format [n_channels, n_timepoints]
-                 info = _dict_to_mne_info(info_dict)
-                 raw = mne.io.RawArray(data, info)
-                 ds = EEGWindowsDataset(
-                     raw=raw,
-                     metadata=metadata,
-                     description=description,
-                     targets_from=targets_from,
-                     last_target_only=last_target_only,
-                 )
-
-             elif dataset_type == "RawDataset":
-                 # Load using inlined function
-                 data, description, info_dict, target_name = _load_raw_from_zarr(
-                     grp, preload
-                 )
-
-                 # Convert to MNE objects and create dataset
-                 # Data is in continuous format [n_channels, n_timepoints]
-                 info = _dict_to_mne_info(info_dict)
-                 raw = mne.io.RawArray(data, info)
-                 ds = RawDataset(raw, description)
-                 if target_name is not None:
-                     ds.target_name = target_name
-
-             else:
-                 raise ValueError(f"Unsupported dataset_type: {dataset_type}")
-
-             datasets.append(ds)
-
-         # Create concat dataset
-         concat_ds = BaseConcatDataset(datasets)
-
-         # Restore preprocessing kwargs (set on individual datasets, not concat)
-         for kwarg_name in [
-             "raw_preproc_kwargs",
-             "window_kwargs",
-             "window_preproc_kwargs",
-         ]:
-             if kwarg_name in root.attrs:
-                 kwargs = json.loads(root.attrs[kwarg_name])
-                 # Set on each individual dataset (where they were originally stored)
-                 for ds in datasets:
-                     setattr(ds, kwarg_name, kwargs)
-
-         return concat_ds
-
-
- # =============================================================================
- # Core Zarr I/O Utilities
- # =============================================================================
-
-
- # TODO: remove when this MNE is solved https://github.com/mne-tools/mne-python/issues/13487
- def _mne_info_to_dict(info):
-     """Convert MNE Info object to dictionary for JSON serialization."""
-     return {
-         "ch_names": info["ch_names"],
-         "sfreq": float(info["sfreq"]),
-         "ch_types": [str(ch_type) for ch_type in info.get_channel_types()],
-         "lowpass": float(info["lowpass"]) if info["lowpass"] is not None else None,
-         "highpass": float(info["highpass"]) if info["highpass"] is not None else None,
-     }
-
-
- def _dict_to_mne_info(info_dict):
-     """Convert dictionary back to MNE Info object."""
-     info = mne.create_info(
-         ch_names=info_dict["ch_names"],
-         sfreq=info_dict["sfreq"],
-         ch_types=info_dict["ch_types"],
-     )
-
-     # Use _unlock() to set filter info when reconstructing from saved metadata
-     # This is necessary because MNE protects these fields to prevent users from
-     # setting filter parameters without actually filtering the data
-     with info._unlock():
-         if info_dict.get("lowpass") is not None:
-             info["lowpass"] = info_dict["lowpass"]
-         if info_dict.get("highpass") is not None:
-             info["highpass"] = info_dict["highpass"]
-
-     return info
-
-
- def _save_windows_to_zarr(
-     grp, data, metadata, description, info, compressor, target_name
- ):
-     """Save windowed data to Zarr group (low-level function)."""
-     # Save data with chunking for random access
-     grp.create_dataset(
-         "data",
-         data=data.astype(np.float32),
-         chunks=(1, data.shape[1], data.shape[2]),
-         compressor=compressor,
-     )
-
-     # Save metadata
-     metadata_json = metadata.to_json(orient="split", date_format="iso")
-     grp.attrs["metadata"] = metadata_json
-     # Save dtypes to preserve them across platforms (int32 vs int64, etc.)
-     metadata_dtypes = metadata.dtypes.apply(str).to_json()
-     grp.attrs["metadata_dtypes"] = metadata_dtypes
-
-     # Save description
-     description_json = description.to_json(date_format="iso")
-     grp.attrs["description"] = description_json
-
-     # Save MNE info
-     grp.attrs["info"] = json.dumps(info)
-
-     # Save target name if provided
-     if target_name is not None:
-         grp.attrs["target_name"] = target_name
-
-
- def _save_eegwindows_to_zarr(
-     grp, raw, metadata, description, info, targets_from, last_target_only, compressor
- ):
-     """Save EEG continuous raw data to Zarr group (low-level function)."""
-     # Extract continuous data from Raw [n_channels, n_timepoints]
-     continuous_data = raw.get_data()
-
-     # Save continuous data with chunking optimized for window extraction
-     # Chunk size: all channels, 10000 timepoints for efficient random access
-     grp.create_dataset(
-         "data",
-         data=continuous_data.astype(np.float32),
-         chunks=(continuous_data.shape[0], min(10000, continuous_data.shape[1])),
-         compressor=compressor,
-     )
-
-     # Save metadata
-     metadata_json = metadata.to_json(orient="split", date_format="iso")
-     grp.attrs["metadata"] = metadata_json
-     # Save dtypes to preserve them across platforms (int32 vs int64, etc.)
-     metadata_dtypes = metadata.dtypes.apply(str).to_json()
-     grp.attrs["metadata_dtypes"] = metadata_dtypes
-
-     # Save description
-     description_json = description.to_json(date_format="iso")
-     grp.attrs["description"] = description_json
-
-     # Save MNE info
-     grp.attrs["info"] = json.dumps(info)
-
-     # Save EEGWindowsDataset-specific attributes
-     grp.attrs["targets_from"] = targets_from
-     grp.attrs["last_target_only"] = last_target_only
-
-
- def _load_windows_from_zarr(grp, preload):
-     """Load windowed data from Zarr group (low-level function)."""
-     # Load metadata
-     metadata = pd.read_json(io.StringIO(grp.attrs["metadata"]), orient="split")
-     # Restore dtypes to preserve them across platforms (int32 vs int64, etc.)
-     dtypes_dict = pd.read_json(io.StringIO(grp.attrs["metadata_dtypes"]), typ="series")
-     for col, dtype_str in dtypes_dict.items():
-         metadata[col] = metadata[col].astype(dtype_str)
-
-     # Load description
-     description = pd.read_json(io.StringIO(grp.attrs["description"]), typ="series")
-
-     # Load info
-     info_dict = json.loads(grp.attrs["info"])
-
-     # Load data
-     if preload:
-         data = grp["data"][:]
-     else:
-         data = grp["data"][:]
-         # TODO: Implement lazy loading properly
-         warnings.warn(
-             "Lazy loading from Zarr not fully implemented yet. "
-             "Loading all data into memory.",
-             UserWarning,
-         )
-
-     # Load target name
-     target_name = grp.attrs.get("target_name", None)
-
-     return data, metadata, description, info_dict, target_name
-
-
- def _load_eegwindows_from_zarr(grp, preload):
-     """Load EEG continuous raw data from Zarr group (low-level function)."""
-     # Load metadata
-     metadata = pd.read_json(io.StringIO(grp.attrs["metadata"]), orient="split")
-     # Restore dtypes to preserve them across platforms (int32 vs int64, etc.)
-     dtypes_dict = pd.read_json(io.StringIO(grp.attrs["metadata_dtypes"]), typ="series")
-     for col, dtype_str in dtypes_dict.items():
-         metadata[col] = metadata[col].astype(dtype_str)
-
-     # Load description
-     description = pd.read_json(io.StringIO(grp.attrs["description"]), typ="series")
-
-     # Load info
-     info_dict = json.loads(grp.attrs["info"])
-
-     # Load data
-     if preload:
-         data = grp["data"][:]
-     else:
-         data = grp["data"][:]
-         warnings.warn(
-             "Lazy loading from Zarr not fully implemented yet. "
-             "Loading all data into memory.",
-             UserWarning,
-         )
-
-     # Load EEGWindowsDataset-specific attributes
-     targets_from = grp.attrs.get("targets_from", "metadata")
-     last_target_only = grp.attrs.get("last_target_only", True)
-
-     return data, metadata, description, info_dict, targets_from, last_target_only
-
-
- def _save_raw_to_zarr(grp, raw, description, info, target_name, compressor):
-     """Save RawDataset continuous raw data to Zarr group (low-level function)."""
-     # Extract continuous data from Raw [n_channels, n_timepoints]
-     continuous_data = raw.get_data()
-
-     # Save continuous data with chunking optimized for efficient access
-     # Chunk size: all channels, 10000 timepoints for efficient random access
-     grp.create_dataset(
-         "data",
-         data=continuous_data.astype(np.float32),
-         chunks=(continuous_data.shape[0], min(10000, continuous_data.shape[1])),
-         compressor=compressor,
-     )
-
-     # Save description
-     description_json = description.to_json(date_format="iso")
-     grp.attrs["description"] = description_json
-
-     # Save MNE info
-     grp.attrs["info"] = json.dumps(info)
-
-     # Save target name if provided
-     if target_name is not None:
-         grp.attrs["target_name"] = target_name
-
-
- def _load_raw_from_zarr(grp, preload):
-     """Load RawDataset continuous raw data from Zarr group (low-level function)."""
-     # Load description
-     description = pd.read_json(io.StringIO(grp.attrs["description"]), typ="series")
-
-     # Load info
-     info_dict = json.loads(grp.attrs["info"])
-
-     # Load data
-     if preload:
-         data = grp["data"][:]
-     else:
-         data = grp["data"][:]
-         # TODO: Implement lazy loading properly
-         warnings.warn(
-             "Lazy loading from Zarr not fully implemented yet. "
-             "Loading all data into memory.",
-             UserWarning,
-         )
-
-     # Load target name
-     target_name = grp.attrs.get("target_name", None)
-
-     return data, description, info_dict, target_name
-
-
- def _create_compressor(compression, compression_level):
-     """Create a Zarr compressor object (zarr v2 API)."""
-     if zarr is False:
-         raise ImportError(
-             "Zarr is not installed. Install with: pip install braindecode[hub]"
-         )
-
-     if not NUMCODECS_AVAILABLE:
-         raise ImportError(
-             "numcodecs is not installed. Install with: pip install braindecode[hub]"
-         )
-
-     # Zarr v2 uses numcodecs compressors
-     if compression == "blosc":
-         return Blosc(cname="zstd", clevel=compression_level)
-     elif compression == "zstd":
-         return Zstd(level=compression_level)
-     elif compression == "gzip":
-         return GZip(level=compression_level)
-     else:
-         return None
-
-
- # TODO: improve content
- def _generate_readme_content(
-     format_info,
-     n_recordings: int,
-     n_channels: int,
-     sfreq,
-     data_type: str,
-     n_windows: int,
-     format: str = "zarr",
- ):
-     """Generate README.md content for a dataset uploaded to the Hub."""
-     # Use safe access for total size and format sfreq nicely
-     total_size_mb = (
-         format_info.get("total_size_mb", 0.0) if isinstance(format_info, dict) else 0.0
-     )
-     sfreq_str = f"{sfreq:g}" if sfreq is not None else "N/A"
-
-     return f"""---
- tags:
- - braindecode
- - eeg
- - neuroscience
- - brain-computer-interface
- license: unknown
- ---
-
- # EEG Dataset
-
- This dataset was created using [braindecode](https://braindecode.org), a library for deep learning with EEG/MEG/ECoG signals.
-
- ## Dataset Information
-
- | Property | Value |
- |---|---:|
- | Number of recordings | {n_recordings} |
- | Dataset type | {data_type} |
- | Number of channels | {n_channels} |
- | Sampling frequency | {sfreq_str} Hz |
- | Number of windows / samples | {n_windows} |
- | Total size | {total_size_mb:.2f} MB |
- | Storage format | {format} |
-
- ## Usage
-
- To load this dataset::
-
- .. code-block:: python
-
-     from braindecode.datasets import BaseConcatDataset
-
-     # Load dataset from Hugging Face Hub
-     dataset = BaseConcatDataset.pull_from_hub("username/dataset-name")
-
-     # Access data
-     X, y, metainfo = dataset[0]
-     # X: EEG data (n_channels, n_times)
-     # y: label/target
-     # metainfo: window indices
-
- ## Using with PyTorch DataLoader
-
- ::
-
-     from torch.utils.data import DataLoader
-
-     # Create DataLoader for training
-     train_loader = DataLoader(
-         dataset,
-         batch_size=32,
-         shuffle=True,
-         num_workers=4
-     )
-
-     # Training loop
-     for X, y, metainfo in train_loader:
-         # X shape: [batch_size, n_channels, n_times]
-         # y shape: [batch_size]
-         # metainfo shape: [batch_size, 2] (start and end indices)
-         # Process your batch...
-
- ## Dataset Format
-
- This dataset is stored in **Zarr** format, optimized for:
- - Fast random access during training (critical for PyTorch DataLoader)
- - Efficient compression with blosc
- - Cloud-native storage compatibility
-
- For more information about braindecode, visit: https://braindecode.org
- """
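For readers inspecting repositories produced by this module: push_to_hub() wrote a dataset.zarr directory (one recording_<i> group per recording, a float32 "data" array chunked for random access, and JSON metadata in the group attributes), alongside format_info.json and a generated README.md. Below is a minimal sketch of that Zarr layout using the zarr v2 / numcodecs API the removed file relies on; the path, shapes, channel names, and attribute values are placeholders, not a definitive reproduction of braindecode's writer.

```python
# Minimal sketch (assumes zarr v2 and numcodecs) of the on-disk layout that the
# removed _convert_to_zarr_inline()/_save_windows_to_zarr() produced. All names
# and shapes below are illustrative placeholders.
import json

import numpy as np
import zarr
from numcodecs import Blosc

store = zarr.DirectoryStore("example_dataset.zarr")  # hypothetical local path
root = zarr.group(store=store, overwrite=True)
root.attrs["n_datasets"] = 1
root.attrs["dataset_type"] = "WindowsDataset"

grp = root.create_group("recording_0")
# (n_windows, n_channels, n_times); one window per chunk for fast random access
windows = np.random.randn(10, 4, 200).astype(np.float32)
grp.create_dataset(
    "data",
    data=windows,
    chunks=(1, windows.shape[1], windows.shape[2]),
    compressor=Blosc(cname="zstd", clevel=5),  # the "blosc" option, level 5
)
# MNE info and pandas metadata were stored as JSON strings in the attributes
grp.attrs["info"] = json.dumps({"ch_names": ["C3", "C4", "Cz", "Pz"], "sfreq": 100.0})

# Reading back mirrors _load_from_zarr_inline(): open the group and index it.
root = zarr.open_group("example_dataset.zarr", mode="r")
print(root["recording_0"]["data"].shape)  # (10, 4, 200)
```

Chunking one window per chunk is what made random access cheap for DataLoader-style sampling, while continuous (Raw) recordings were instead chunked as all channels by at most 10000 timepoints, as the removed save functions show.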