braindecode 1.3.0.dev177069446__py3-none-any.whl → 1.3.0.dev177628147__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the changes between the package versions exactly as they appear in the public registry.
Files changed (106)
  1. braindecode/augmentation/__init__.py +3 -5
  2. braindecode/augmentation/base.py +5 -8
  3. braindecode/augmentation/functional.py +22 -25
  4. braindecode/augmentation/transforms.py +42 -51
  5. braindecode/classifier.py +16 -11
  6. braindecode/datasets/__init__.py +3 -5
  7. braindecode/datasets/base.py +13 -17
  8. braindecode/datasets/bbci.py +14 -13
  9. braindecode/datasets/bcicomp.py +5 -4
  10. braindecode/datasets/{bids/datasets.py → bids.py} +18 -12
  11. braindecode/datasets/{bids/iterable.py → experimental.py} +6 -8
  12. braindecode/datasets/{bids/hub.py → hub.py} +350 -375
  13. braindecode/datasets/{bids/hub_validation.py → hub_validation.py} +1 -2
  14. braindecode/datasets/mne.py +19 -19
  15. braindecode/datasets/moabb.py +10 -10
  16. braindecode/datasets/nmt.py +56 -58
  17. braindecode/datasets/sleep_physio_challe_18.py +5 -3
  18. braindecode/datasets/sleep_physionet.py +5 -5
  19. braindecode/datasets/tuh.py +18 -21
  20. braindecode/datasets/xy.py +9 -10
  21. braindecode/datautil/__init__.py +3 -3
  22. braindecode/datautil/serialization.py +20 -22
  23. braindecode/datautil/util.py +7 -120
  24. braindecode/eegneuralnet.py +52 -22
  25. braindecode/functional/functions.py +10 -7
  26. braindecode/functional/initialization.py +2 -3
  27. braindecode/models/__init__.py +3 -5
  28. braindecode/models/atcnet.py +39 -43
  29. braindecode/models/attentionbasenet.py +41 -37
  30. braindecode/models/attn_sleep.py +24 -26
  31. braindecode/models/base.py +6 -6
  32. braindecode/models/bendr.py +26 -50
  33. braindecode/models/biot.py +30 -61
  34. braindecode/models/contrawr.py +5 -5
  35. braindecode/models/ctnet.py +35 -35
  36. braindecode/models/deep4.py +5 -5
  37. braindecode/models/deepsleepnet.py +7 -7
  38. braindecode/models/eegconformer.py +26 -31
  39. braindecode/models/eeginception_erp.py +2 -2
  40. braindecode/models/eeginception_mi.py +6 -6
  41. braindecode/models/eegitnet.py +5 -5
  42. braindecode/models/eegminer.py +1 -1
  43. braindecode/models/eegnet.py +3 -3
  44. braindecode/models/eegnex.py +2 -2
  45. braindecode/models/eegsimpleconv.py +2 -2
  46. braindecode/models/eegsym.py +7 -7
  47. braindecode/models/eegtcnet.py +6 -6
  48. braindecode/models/fbcnet.py +2 -2
  49. braindecode/models/fblightconvnet.py +3 -3
  50. braindecode/models/fbmsnet.py +3 -3
  51. braindecode/models/hybrid.py +2 -2
  52. braindecode/models/ifnet.py +5 -5
  53. braindecode/models/labram.py +46 -70
  54. braindecode/models/luna.py +5 -60
  55. braindecode/models/medformer.py +21 -23
  56. braindecode/models/msvtnet.py +15 -15
  57. braindecode/models/patchedtransformer.py +55 -55
  58. braindecode/models/sccnet.py +2 -2
  59. braindecode/models/shallow_fbcsp.py +3 -5
  60. braindecode/models/signal_jepa.py +12 -39
  61. braindecode/models/sinc_shallow.py +4 -3
  62. braindecode/models/sleep_stager_blanco_2020.py +2 -2
  63. braindecode/models/sleep_stager_chambon_2018.py +2 -2
  64. braindecode/models/sparcnet.py +8 -8
  65. braindecode/models/sstdpn.py +869 -869
  66. braindecode/models/summary.csv +17 -19
  67. braindecode/models/syncnet.py +2 -2
  68. braindecode/models/tcn.py +5 -5
  69. braindecode/models/tidnet.py +3 -3
  70. braindecode/models/tsinception.py +3 -3
  71. braindecode/models/usleep.py +7 -7
  72. braindecode/models/util.py +14 -165
  73. braindecode/modules/__init__.py +1 -9
  74. braindecode/modules/activation.py +3 -29
  75. braindecode/modules/attention.py +0 -123
  76. braindecode/modules/blocks.py +1 -53
  77. braindecode/modules/convolution.py +0 -53
  78. braindecode/modules/filter.py +0 -31
  79. braindecode/modules/layers.py +0 -84
  80. braindecode/modules/linear.py +1 -22
  81. braindecode/modules/stats.py +0 -10
  82. braindecode/modules/util.py +0 -9
  83. braindecode/modules/wrapper.py +0 -17
  84. braindecode/preprocessing/preprocess.py +0 -3
  85. braindecode/regressor.py +18 -15
  86. braindecode/samplers/ssl.py +1 -1
  87. braindecode/util.py +28 -38
  88. braindecode/version.py +1 -1
  89. braindecode-1.3.0.dev177628147.dist-info/METADATA +202 -0
  90. braindecode-1.3.0.dev177628147.dist-info/RECORD +114 -0
  91. braindecode/datasets/bids/__init__.py +0 -54
  92. braindecode/datasets/bids/format.py +0 -717
  93. braindecode/datasets/bids/hub_format.py +0 -717
  94. braindecode/datasets/bids/hub_io.py +0 -197
  95. braindecode/datasets/chb_mit.py +0 -163
  96. braindecode/datasets/siena.py +0 -162
  97. braindecode/datasets/utils.py +0 -67
  98. braindecode/models/brainmodule.py +0 -845
  99. braindecode/models/config.py +0 -233
  100. braindecode/models/reve.py +0 -843
  101. braindecode-1.3.0.dev177069446.dist-info/METADATA +0 -230
  102. braindecode-1.3.0.dev177069446.dist-info/RECORD +0 -124
  103. {braindecode-1.3.0.dev177069446.dist-info → braindecode-1.3.0.dev177628147.dist-info}/WHEEL +0 -0
  104. {braindecode-1.3.0.dev177069446.dist-info → braindecode-1.3.0.dev177628147.dist-info}/licenses/LICENSE.txt +0 -0
  105. {braindecode-1.3.0.dev177069446.dist-info → braindecode-1.3.0.dev177628147.dist-info}/licenses/NOTICE.txt +0 -0
  106. {braindecode-1.3.0.dev177069446.dist-info → braindecode-1.3.0.dev177628147.dist-info}/top_level.txt +0 -0
@@ -1,36 +1,19 @@
- # mypy: ignore-errors
  """
  Hugging Face Hub integration for EEG datasets.

  This module provides push_to_hub() and pull_from_hub() functionality
  for braindecode datasets, similar to the model Hub integration.
-
- .. warning::
- The format is **BIDS-inspired**, not **BIDS-compliant**. The metadata
- files are BIDS-compliant, but the data is stored in Zarr format for
- efficient training, which is not a valid BIDS EEG format.
-
- The format follows a BIDS-inspired sourcedata structure:
- - sourcedata/braindecode/
- - dataset_description.json (BIDS-compliant)
- - participants.tsv (BIDS-compliant)
- - dataset.zarr/ (NOT BIDS-compliant - efficient data store)
- - sub-<label>/
- - eeg/
- - *_events.tsv (BIDS-compliant)
- - *_channels.tsv (BIDS-compliant)
- - *_eeg.json (BIDS-compliant)
  """

  # Authors: Kuntal Kokate
- # Bruno Aristimunha <b.aristimunha@gmail.com>
  #
  # License: BSD (3-clause)

+ import io
  import json
  import logging
  import tempfile
- from datetime import timedelta
+ import warnings
  from pathlib import Path
  from typing import TYPE_CHECKING, List, Optional, Union

@@ -38,28 +21,28 @@ import mne
  import numpy as np
  import pandas as pd
  import scipy
- from mne._fiff.meas_info import Info
  from mne.utils import _soft_import

+ # TODO: Simplify this logic in the future with zarr v3+ only
+ # Optional imports for Hub functionality
+ try:
+ from numcodecs import Blosc, GZip, Zstd
+
+ NUMCODECS_AVAILABLE = True
+ except ImportError:
+ NUMCODECS_AVAILABLE = False
+ Blosc = GZip = Zstd = None
+
  if TYPE_CHECKING:
- from ..base import BaseDataset
+ from .base import BaseDataset

  import braindecode

+ # Import shared validation utilities
+ from . import hub_validation
+
  # Import registry for dynamic class lookup (avoids circular imports)
- from ..registry import get_dataset_class, get_dataset_type
-
- # Hub format and validation utilities
- from . import hub_format, hub_validation
- from .hub_io import (
- _create_compressor,
- _load_eegwindows_from_zarr,
- _load_raw_from_zarr,
- _load_windows_from_zarr,
- _save_eegwindows_to_zarr,
- _save_raw_to_zarr,
- _save_windows_to_zarr,
- )
+ from .registry import get_dataset_class, get_dataset_type

  # Lazy import zarr and huggingface_hub
  zarr = _soft_import("zarr", purpose="hugging face integration", strict=False)
@@ -102,15 +85,13 @@ class HubDatasetMixin:
  create_pr: bool = False,
  compression: str = "blosc",
  compression_level: int = 5,
- pipeline_name: str = "braindecode",
  ) -> str:
  """
- Upload the dataset to the Hugging Face Hub in BIDS-like Zarr format.
+ Upload the dataset to the Hugging Face Hub in Zarr format.

  The dataset is converted to Zarr format with blosc compression, which provides
- optimal random access performance for PyTorch training. The data is stored
- in a BIDS sourcedata-like structure with events.tsv, channels.tsv,
- and participants.tsv sidecar files.
+ optimal random access performance for PyTorch training (based on comprehensive
+ benchmarking).

  Parameters
  ----------
@@ -128,8 +109,6 @@ class HubDatasetMixin:
  Compression algorithm for Zarr. Options: "blosc", "zstd", "gzip", None.
  compression_level : int, default=5
  Compression level (0-9). Level 5 provides optimal balance.
- pipeline_name : str, default="braindecode"
- Name of the processing pipeline for BIDS sourcedata.

  Returns
  -------
@@ -146,11 +125,18 @@ class HubDatasetMixin:
  Examples
  --------
  >>> dataset = NMT(path=path, preload=True)
- >>> # Upload with BIDS-like structure
+ >>> # Upload with default settings (zarr with blosc compression)
  >>> url = dataset.push_to_hub(
  ... repo_id="myusername/nmt-dataset",
  ... commit_message="Upload NMT EEG dataset"
  ... )
+ >>>
+ >>> # Or customize compression
+ >>> url = dataset.push_to_hub(
+ ... repo_id="myusername/nmt-dataset",
+ ... compression="blosc",
+ ... compression_level=5
+ ... )
  """
  if huggingface_hub is False or zarr is False:
  raise ImportError(
@@ -177,44 +163,25 @@ class HubDatasetMixin:
  with tempfile.TemporaryDirectory() as tmpdir:
  tmp_path = Path(tmpdir)

- # Create BIDS-like sourcedata structure
- log.info("Creating BIDS-like sourcedata structure...")
- bids_layout = hub_format.BIDSSourcedataLayout(
- tmp_path, pipeline_name=pipeline_name
- )
- sourcedata_dir = bids_layout.create_structure()
-
- # Save dataset_description.json
- bids_layout.save_dataset_description()
-
- # Save participants.tsv
- descriptions = [ds.description for ds in self.datasets]
- bids_layout.save_participants(descriptions)
-
- # Save BIDS sidecar files for each recording
- self._save_bids_sidecar_files(bids_layout)
-
- # Convert dataset to Zarr format inside sourcedata
+ # Convert dataset to Zarr format
  log.info("Converting dataset to Zarr format...")
- dataset_path = sourcedata_dir / "dataset.zarr"
-
+ dataset_path = tmp_path / "dataset.zarr"
  self._convert_to_zarr_inline(
  dataset_path,
  compression,
  compression_level,
  )

- # Save dataset metadata (README.md)
+ # Save dataset metadata
  self._save_dataset_card(tmp_path)

  # Save format info
  format_info_path = tmp_path / "format_info.json"
- with open(format_info_path, "w", encoding="utf-8") as f:
+ with open(format_info_path, "w") as f:
  format_info = self._get_format_info_inline()
  json.dump(
  {
  "format": "zarr",
- "pipeline_name": pipeline_name,
  "compression": compression,
  "compression_level": compression_level,
  "braindecode_version": braindecode.__version__,
@@ -227,8 +194,8 @@ class HubDatasetMixin:
  # Default commit message
  if commit_message is None:
  commit_message = (
- f"Upload EEG dataset in BIDS-like "
- f"Zarr format ({len(self.datasets)} recordings)"
+ f"Upload EEG dataset in Zarr format "
+ f"({len(self.datasets)} recordings)"
  )

  # Upload folder to Hub
@@ -248,15 +215,13 @@ class HubDatasetMixin:
  except Exception as e:
  raise RuntimeError(f"Failed to upload dataset: {e}")

- def _save_dataset_card(self, path: Path, bids_inspired: bool = True) -> None:
+ def _save_dataset_card(self, path: Path) -> None:
  """Generate and save a dataset card (README.md) with metadata.

  Parameters
  ----------
  path : Path
  Directory where README.md will be saved.
- bids_inspired : bool
- Whether to include BIDS-inspired format documentation.
  """
  # Get info, which also validates uniformity across all datasets
  format_info = self._get_format_info_inline()
@@ -269,27 +234,18 @@ class HubDatasetMixin:

  n_windows = format_info["total_samples"]

- # Compute total duration across all recordings
- total_duration = 0.0
  if dataset_type == "WindowsDataset":
  n_channels = len(first_ds.windows.ch_names)
  data_type = "Windowed (from Epochs object)"
  sfreq = first_ds.windows.info["sfreq"]
- for ds in self.datasets:
- epoch_length = ds.windows.tmax - ds.windows.tmin
- total_duration += len(ds.windows) * epoch_length
  elif dataset_type == "EEGWindowsDataset":
  n_channels = len(first_ds.raw.ch_names)
  sfreq = first_ds.raw.info["sfreq"]
  data_type = "Windowed (from Raw object)"
- for ds in self.datasets:
- total_duration += ds.raw.n_times / ds.raw.info["sfreq"]
  elif dataset_type == "RawDataset":
  n_channels = len(first_ds.raw.ch_names)
  sfreq = first_ds.raw.info["sfreq"]
  data_type = "Continuous (Raw)"
- for ds in self.datasets:
- total_duration += ds.raw.n_times / ds.raw.info["sfreq"]
  else:
  raise TypeError(f"Unsupported dataset type: {dataset_type}")

@@ -301,99 +257,13 @@ class HubDatasetMixin:
  sfreq=sfreq,
  data_type=data_type,
  n_windows=n_windows,
- total_duration=total_duration,
  )

  # Save README
  readme_path = path / "README.md"
- with open(readme_path, "w", encoding="utf-8") as f:
+ with open(readme_path, "w") as f:
  f.write(readme_content)

- def _save_bids_sidecar_files(
- self, bids_layout: "hub_format.BIDSSourcedataLayout"
- ) -> None:
- """Save BIDS-compliant sidecar files for each recording.
-
- This creates events.tsv, channels.tsv, and EEG sidecar JSON files
- for each recording in a BIDS-like directory structure.
-
- Parameters
- ----------
- bids_layout : BIDSSourcedataLayout
- BIDS layout object for path generation.
- """
- dataset_type = get_dataset_type(self.datasets[0])
-
- for i_ds, ds in enumerate(self.datasets):
- # Get BIDS entities from description
- description = ds.description if ds.description is not None else pd.Series()
-
- # Get BIDSPath for this recording using mne_bids
- bids_path = bids_layout.get_bids_path(description)
-
- # Create subject directory
- bids_path.mkdir(exist_ok=True)
-
- # Get metadata and info based on dataset type
- # Also compute recording_duration, recording_type, and epoch_length
- recording_duration = None
- recording_type = None
- epoch_length = None
-
- if dataset_type == "WindowsDataset":
- info = ds.windows.info
- metadata = ds.windows.metadata
- sfreq = info["sfreq"]
- # WindowsDataset contains pre-cut epochs
- recording_type = "epoched"
- # Use MNE's tmax - tmin for epoch length
- epoch_length = ds.windows.tmax - ds.windows.tmin
- # Total duration = number of epochs * epoch length
- n_epochs = len(ds.windows)
- recording_duration = n_epochs * epoch_length
- elif dataset_type == "EEGWindowsDataset":
- info = ds.raw.info
- metadata = ds.metadata
- sfreq = info["sfreq"]
- # EEGWindowsDataset has continuous raw with window metadata
- recording_type = "epoched"
- # Use MNE Raw's duration property
- recording_duration = ds.raw.duration
- # Compute epoch_length from metadata if available
- if metadata is not None and len(metadata) > 0:
- i_start = metadata["i_start_in_trial"].iloc[0]
- i_stop = metadata["i_stop_in_trial"].iloc[0]
- epoch_length = (i_stop - i_start) / sfreq
- elif dataset_type == "RawDataset":
- info = ds.raw.info
- metadata = None
- sfreq = info["sfreq"]
- # RawDataset is continuous
- recording_type = "continuous"
- # Use MNE Raw's duration property
- recording_duration = ds.raw.duration
- else:
- continue
-
- # Determine task name from description or BIDSPath
- task_name = bids_path.task or "unknown"
-
- # Save BIDS sidecar files using mne_bids BIDSPath
- hub_format.save_bids_sidecar_files(
- bids_path=bids_path,
- info=info,
- metadata=metadata,
- sfreq=sfreq,
- task_name=str(task_name),
- recording_duration=recording_duration,
- recording_type=recording_type,
- epoch_length=epoch_length,
- )
-
- log.debug(
- f"Saved BIDS sidecar files for recording {i_ds} to {bids_path.directory}"
- )
-
  @classmethod
  def pull_from_hub(
  cls,
@@ -479,19 +349,8 @@ class HubDatasetMixin:
  else:
  format_info = {}

- pipeline_name = format_info.get("pipeline_name", "braindecode")
-
- # Find zarr dataset path (try sourcedata, derivatives, then root)
- zarr_path = (
- Path(dataset_dir) / "sourcedata" / pipeline_name / "dataset.zarr"
- )
- if not zarr_path.exists():
- zarr_path = (
- Path(dataset_dir) / "derivatives" / pipeline_name / "dataset.zarr"
- )
- if not zarr_path.exists():
- zarr_path = Path(dataset_dir) / "dataset.zarr"
-
+ # Load zarr dataset
+ zarr_path = Path(dataset_dir) / "dataset.zarr"
  if not zarr_path.exists():
  raise FileNotFoundError(
  f"Zarr dataset not found at {zarr_path}. "
@@ -500,9 +359,6 @@ class HubDatasetMixin:

  dataset = cls._load_from_zarr_inline(zarr_path, preload)

- # Load BIDS metadata if available
- cls._load_bids_metadata(dataset, Path(dataset_dir), pipeline_name)
-
  log.info(f"Dataset loaded successfully from {repo_id}")
  log.info(f"Recordings: {len(dataset.datasets)}")
  log.info(
@@ -522,74 +378,6 @@ class HubDatasetMixin:
  except Exception as e:
  raise RuntimeError(f"Failed to load dataset from Hub: {e}")

- @classmethod
- def _load_bids_metadata(
- cls,
- dataset,
- dataset_dir: Path,
- pipeline_name: str,
- ) -> None:
- """Load BIDS metadata from sidecar files and attach to dataset.
-
- Parameters
- ----------
- dataset : BaseConcatDataset
- The loaded dataset to attach metadata to.
- dataset_dir : Path
- Root directory of the downloaded dataset.
- pipeline_name : str
- Name of the processing pipeline.
- """
- # Try sourcedata first, fall back to derivatives for backwards compatibility
- sourcedata_dir = dataset_dir / "sourcedata" / pipeline_name
- if not sourcedata_dir.exists():
- sourcedata_dir = dataset_dir / "derivatives" / pipeline_name
-
- # Load participants.tsv if available
- participants_path = sourcedata_dir / "participants.tsv"
- if participants_path.exists():
- try:
- participants_df = pd.read_csv(participants_path, sep="\t")
- # Store as attribute on the concat dataset
- dataset.participants = participants_df
- log.debug(
- f"Loaded participants info for {len(participants_df)} subjects"
- )
- except Exception as e:
- log.warning(f"Failed to load participants.tsv: {e}")
-
- # Create layout for path generation
- bids_layout = hub_format.BIDSSourcedataLayout(
- dataset_dir, pipeline_name=pipeline_name
- )
-
- # Try to load events.tsv files and attach to individual datasets
- for i_ds, ds in enumerate(dataset.datasets):
- description = ds.description if ds.description is not None else pd.Series()
-
- # Get BIDSPath for this recording
- bids_path = bids_layout.get_bids_path(description)
-
- # Load events.tsv if available
- events_path = bids_path.copy().update(suffix="events", extension=".tsv")
- if events_path.fpath.exists():
- try:
- events_df = pd.read_csv(events_path.fpath, sep="\t")
- ds.bids_events = events_df
- log.debug(f"Loaded events for recording {i_ds}")
- except Exception as e:
- log.warning(f"Failed to load events for recording {i_ds}: {e}")
-
- # Load channels.tsv if available
- channels_path = bids_path.copy().update(suffix="channels", extension=".tsv")
- if channels_path.fpath.exists():
- try:
- channels_df = pd.read_csv(channels_path.fpath, sep="\t")
- ds.bids_channels = channels_df
- log.debug(f"Loaded channels for recording {i_ds}")
- except Exception as e:
- log.warning(f"Failed to load channels for recording {i_ds}: {e}")
-
  def _convert_to_zarr_inline(
  self,
  output_path: Path,
@@ -609,8 +397,9 @@ class HubDatasetMixin:
  f"{output_path} already exists. Set overwrite=True to replace it."
  )

- # Create zarr store (zarr v3 API)
- root = zarr.open(str(output_path), mode="w")
+ # Create zarr store (zarr v2 API)
+ store = zarr.DirectoryStore(str(output_path))
+ root = zarr.group(store=store, overwrite=False)

  # Validate uniformity across all datasets using shared validation
  dataset_type, _, _ = hub_validation.validate_dataset_uniformity(self.datasets)
@@ -655,7 +444,7 @@ class HubDatasetMixin:
  data = ds.windows.get_data()
  metadata = ds.windows.metadata
  description = ds.description
- info_dict = ds.windows.info.to_json_dict()
+ info_dict = _mne_info_to_dict(ds.windows.info)
  target_name = ds.target_name if hasattr(ds, "target_name") else None

  # Save using inlined function
@@ -668,7 +457,7 @@ class HubDatasetMixin:
  raw = ds.raw
  metadata = ds.metadata
  description = ds.description
- info_dict = ds.raw.info.to_json_dict()
+ info_dict = _mne_info_to_dict(ds.raw.info)
  targets_from = ds.targets_from
  last_target_only = ds.last_target_only

@@ -688,7 +477,7 @@ class HubDatasetMixin:
  # Get continuous raw data from RawDataset
  raw = ds.raw
  description = ds.description
- info_dict = ds.raw.info.to_json_dict()
+ info_dict = _mne_info_to_dict(ds.raw.info)
  target_name = ds.target_name if hasattr(ds, "target_name") else None

  # Save using inlined function
@@ -741,8 +530,9 @@ class HubDatasetMixin:
  if not input_path.exists():
  raise FileNotFoundError(f"{input_path} does not exist.")

- # Open zarr store (zarr v3 API)
- root = zarr.open(str(input_path), mode="r")
+ # Open zarr store (zarr v2 API)
+ store = zarr.DirectoryStore(str(input_path))
+ root = zarr.group(store=store)

  n_datasets = root.attrs["n_datasets"]
  dataset_type = root.attrs["dataset_type"]
@@ -764,7 +554,7 @@ class HubDatasetMixin:
  )

  # Convert to MNE objects and create dataset
- info = Info.from_json_dict(info_dict)
+ info = _dict_to_mne_info(info_dict)
  events = np.column_stack(
  [
  metadata["i_start_in_trial"].values,
@@ -790,7 +580,7 @@ class HubDatasetMixin:

  # Convert to MNE objects and create dataset
  # Data is already in continuous format [n_channels, n_timepoints]
- info = Info.from_json_dict(info_dict)
+ info = _dict_to_mne_info(info_dict)
  raw = mne.io.RawArray(data, info)
  ds = EEGWindowsDataset(
  raw=raw,
@@ -808,7 +598,7 @@ class HubDatasetMixin:

  # Convert to MNE objects and create dataset
  # Data is in continuous format [n_channels, n_timepoints]
- info = Info.from_json_dict(info_dict)
+ info = _dict_to_mne_info(info_dict)
  raw = mne.io.RawArray(data, info)
  ds = RawDataset(raw, description)
  if target_name is not None:
@@ -837,6 +627,251 @@ class HubDatasetMixin:
  return concat_ds


+ # =============================================================================
+ # Core Zarr I/O Utilities
+ # =============================================================================
+
+
+ # TODO: remove when this MNE is solved https://github.com/mne-tools/mne-python/issues/13487
+ def _mne_info_to_dict(info):
+ """Convert MNE Info object to dictionary for JSON serialization."""
+ return {
+ "ch_names": info["ch_names"],
+ "sfreq": float(info["sfreq"]),
+ "ch_types": [str(ch_type) for ch_type in info.get_channel_types()],
+ "lowpass": float(info["lowpass"]) if info["lowpass"] is not None else None,
+ "highpass": float(info["highpass"]) if info["highpass"] is not None else None,
+ }
+
+
+ def _dict_to_mne_info(info_dict):
+ """Convert dictionary back to MNE Info object."""
+ info = mne.create_info(
+ ch_names=info_dict["ch_names"],
+ sfreq=info_dict["sfreq"],
+ ch_types=info_dict["ch_types"],
+ )
+
+ # Use _unlock() to set filter info when reconstructing from saved metadata
+ # This is necessary because MNE protects these fields to prevent users from
+ # setting filter parameters without actually filtering the data
+ with info._unlock():
+ if info_dict.get("lowpass") is not None:
+ info["lowpass"] = info_dict["lowpass"]
+ if info_dict.get("highpass") is not None:
+ info["highpass"] = info_dict["highpass"]
+
+ return info
+
+
+ def _save_windows_to_zarr(
+ grp, data, metadata, description, info, compressor, target_name
+ ):
+ """Save windowed data to Zarr group (low-level function)."""
+ # Save data with chunking for random access
+ grp.create_dataset(
+ "data",
+ data=data.astype(np.float32),
+ chunks=(1, data.shape[1], data.shape[2]),
+ compressor=compressor,
+ )
+
+ # Save metadata
+ metadata_json = metadata.to_json(orient="split", date_format="iso")
+ grp.attrs["metadata"] = metadata_json
+ # Save dtypes to preserve them across platforms (int32 vs int64, etc.)
+ metadata_dtypes = metadata.dtypes.apply(str).to_json()
+ grp.attrs["metadata_dtypes"] = metadata_dtypes
+
+ # Save description
+ description_json = description.to_json(date_format="iso")
+ grp.attrs["description"] = description_json
+
+ # Save MNE info
+ grp.attrs["info"] = json.dumps(info)
+
+ # Save target name if provided
+ if target_name is not None:
+ grp.attrs["target_name"] = target_name
+
+
+ def _save_eegwindows_to_zarr(
+ grp, raw, metadata, description, info, targets_from, last_target_only, compressor
+ ):
+ """Save EEG continuous raw data to Zarr group (low-level function)."""
+ # Extract continuous data from Raw [n_channels, n_timepoints]
+ continuous_data = raw.get_data()
+
+ # Save continuous data with chunking optimized for window extraction
+ # Chunk size: all channels, 10000 timepoints for efficient random access
+ grp.create_dataset(
+ "data",
+ data=continuous_data.astype(np.float32),
+ chunks=(continuous_data.shape[0], min(10000, continuous_data.shape[1])),
+ compressor=compressor,
+ )
+
+ # Save metadata
+ metadata_json = metadata.to_json(orient="split", date_format="iso")
+ grp.attrs["metadata"] = metadata_json
+ # Save dtypes to preserve them across platforms (int32 vs int64, etc.)
+ metadata_dtypes = metadata.dtypes.apply(str).to_json()
+ grp.attrs["metadata_dtypes"] = metadata_dtypes
+
+ # Save description
+ description_json = description.to_json(date_format="iso")
+ grp.attrs["description"] = description_json
+
+ # Save MNE info
+ grp.attrs["info"] = json.dumps(info)
+
+ # Save EEGWindowsDataset-specific attributes
+ grp.attrs["targets_from"] = targets_from
+ grp.attrs["last_target_only"] = last_target_only
+
+
+ def _load_windows_from_zarr(grp, preload):
+ """Load windowed data from Zarr group (low-level function)."""
+ # Load metadata
+ metadata = pd.read_json(io.StringIO(grp.attrs["metadata"]), orient="split")
+ # Restore dtypes to preserve them across platforms (int32 vs int64, etc.)
+ dtypes_dict = pd.read_json(io.StringIO(grp.attrs["metadata_dtypes"]), typ="series")
+ for col, dtype_str in dtypes_dict.items():
+ metadata[col] = metadata[col].astype(dtype_str)
+
+ # Load description
+ description = pd.read_json(io.StringIO(grp.attrs["description"]), typ="series")
+
+ # Load info
+ info_dict = json.loads(grp.attrs["info"])
+
+ # Load data
+ if preload:
+ data = grp["data"][:]
+ else:
+ data = grp["data"][:]
+ # TODO: Implement lazy loading properly
+ warnings.warn(
+ "Lazy loading from Zarr not fully implemented yet. "
+ "Loading all data into memory.",
+ UserWarning,
+ )
+
+ # Load target name
+ target_name = grp.attrs.get("target_name", None)
+
+ return data, metadata, description, info_dict, target_name
+
+
+ def _load_eegwindows_from_zarr(grp, preload):
+ """Load EEG continuous raw data from Zarr group (low-level function)."""
+ # Load metadata
+ metadata = pd.read_json(io.StringIO(grp.attrs["metadata"]), orient="split")
+ # Restore dtypes to preserve them across platforms (int32 vs int64, etc.)
+ dtypes_dict = pd.read_json(io.StringIO(grp.attrs["metadata_dtypes"]), typ="series")
+ for col, dtype_str in dtypes_dict.items():
+ metadata[col] = metadata[col].astype(dtype_str)
+
+ # Load description
+ description = pd.read_json(io.StringIO(grp.attrs["description"]), typ="series")
+
+ # Load info
+ info_dict = json.loads(grp.attrs["info"])
+
+ # Load data
+ if preload:
+ data = grp["data"][:]
+ else:
+ data = grp["data"][:]
+ warnings.warn(
+ "Lazy loading from Zarr not fully implemented yet. "
+ "Loading all data into memory.",
+ UserWarning,
+ )
+
+ # Load EEGWindowsDataset-specific attributes
+ targets_from = grp.attrs.get("targets_from", "metadata")
+ last_target_only = grp.attrs.get("last_target_only", True)
+
+ return data, metadata, description, info_dict, targets_from, last_target_only
+
+
+ def _save_raw_to_zarr(grp, raw, description, info, target_name, compressor):
+ """Save RawDataset continuous raw data to Zarr group (low-level function)."""
+ # Extract continuous data from Raw [n_channels, n_timepoints]
+ continuous_data = raw.get_data()
+
+ # Save continuous data with chunking optimized for efficient access
+ # Chunk size: all channels, 10000 timepoints for efficient random access
+ grp.create_dataset(
+ "data",
+ data=continuous_data.astype(np.float32),
+ chunks=(continuous_data.shape[0], min(10000, continuous_data.shape[1])),
+ compressor=compressor,
+ )
+
+ # Save description
+ description_json = description.to_json(date_format="iso")
+ grp.attrs["description"] = description_json
+
+ # Save MNE info
+ grp.attrs["info"] = json.dumps(info)
+
+ # Save target name if provided
+ if target_name is not None:
+ grp.attrs["target_name"] = target_name
+
+
+ def _load_raw_from_zarr(grp, preload):
+ """Load RawDataset continuous raw data from Zarr group (low-level function)."""
+ # Load description
+ description = pd.read_json(io.StringIO(grp.attrs["description"]), typ="series")
+
+ # Load info
+ info_dict = json.loads(grp.attrs["info"])
+
+ # Load data
+ if preload:
+ data = grp["data"][:]
+ else:
+ data = grp["data"][:]
+ # TODO: Implement lazy loading properly
+ warnings.warn(
+ "Lazy loading from Zarr not fully implemented yet. "
+ "Loading all data into memory.",
+ UserWarning,
+ )
+
+ # Load target name
+ target_name = grp.attrs.get("target_name", None)
+
+ return data, description, info_dict, target_name
+
+
+ def _create_compressor(compression, compression_level):
+ """Create a Zarr compressor object (zarr v2 API)."""
+ if zarr is False:
+ raise ImportError(
+ "Zarr is not installed. Install with: pip install braindecode[hub]"
+ )
+
+ if not NUMCODECS_AVAILABLE:
+ raise ImportError(
+ "numcodecs is not installed. Install with: pip install braindecode[hub]"
+ )
+
+ # Zarr v2 uses numcodecs compressors
+ if compression == "blosc":
+ return Blosc(cname="zstd", clevel=compression_level)
+ elif compression == "zstd":
+ return Zstd(level=compression_level)
+ elif compression == "gzip":
+ return GZip(level=compression_level)
+ else:
+ return None
+
+
+ # TODO: improve content
  def _generate_readme_content(
  format_info,
  n_recordings: int,
@@ -844,144 +879,84 @@ def _generate_readme_content(
  sfreq,
  data_type: str,
  n_windows: int,
- total_duration: float | None = None,
  format: str = "zarr",
  ):
- """Generate README.md content for a dataset uploaded to the Hub.
-
- Parameters
- ----------
- format_info : dict
- Dictionary containing format metadata (e.g., total_size_mb).
- n_recordings : int
- Number of recordings in the dataset.
- n_channels : int
- Number of EEG channels.
- sfreq : float or None
- Sampling frequency in Hz.
- data_type : str
- Type of dataset (e.g., "Windowed", "Continuous").
- n_windows : int
- Number of windows/samples in the dataset.
- total_duration : float or None
- Total duration in seconds across all recordings.
- format : str
- Storage format (default: "zarr").
-
- Returns
- -------
- str
- Markdown content for the README.md file.
- """
+ """Generate README.md content for a dataset uploaded to the Hub."""
+ # Use safe access for total size and format sfreq nicely
  total_size_mb = (
  format_info.get("total_size_mb", 0.0) if isinstance(format_info, dict) else 0.0
  )
  sfreq_str = f"{sfreq:g}" if sfreq is not None else "N/A"

- duration_str = (
- str(timedelta(seconds=int(total_duration))) if total_duration else "N/A"
- )
-
  return f"""---
  tags:
  - braindecode
  - eeg
  - neuroscience
  - brain-computer-interface
- - deep-learning
  license: unknown
  ---

  # EEG Dataset

- This dataset was created using [braindecode](https://braindecode.org), a deep
- learning library for EEG/MEG/ECoG signals.
+ This dataset was created using [braindecode](https://braindecode.org), a library for deep learning with EEG/MEG/ECoG signals.

  ## Dataset Information

  | Property | Value |
- |----------|------:|
- | Recordings | {n_recordings} |
- | Type | {data_type} |
- | Channels | {n_channels} |
+ |---|---:|
+ | Number of recordings | {n_recordings} |
+ | Dataset type | {data_type} |
+ | Number of channels | {n_channels} |
  | Sampling frequency | {sfreq_str} Hz |
- | Total duration | {duration_str} |
- | Windows/samples | {n_windows:,} |
- | Size | {total_size_mb:.2f} MB |
- | Format | {format} |
-
- ## Quick Start
-
- ```python
- from braindecode.datasets import BaseConcatDataset
-
- # Load from Hugging Face Hub
- dataset = BaseConcatDataset.pull_from_hub("username/dataset-name")
-
- # Access a sample
- X, y, metainfo = dataset[0]
- # X: EEG data [n_channels, n_times]
- # y: target label
- # metainfo: window indices
- ```
-
- ## Training with PyTorch
-
- ```python
- from torch.utils.data import DataLoader
-
- loader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4)
-
- for X, y, metainfo in loader:
- # X: [batch_size, n_channels, n_times]
- # y: [batch_size]
- pass # Your training code
- ```
-
- ## BIDS-inspired Structure
-
- This dataset uses a **BIDS-inspired** organization. Metadata files follow BIDS
- conventions, while data is stored in Zarr format for efficient deep learning.
-
- **BIDS-style metadata:**
- - `dataset_description.json` - Dataset information
- - `participants.tsv` - Subject metadata
- - `*_events.tsv` - Trial/window events
- - `*_channels.tsv` - Channel information
- - `*_eeg.json` - Recording parameters
-
- **Data storage:**
- - `dataset.zarr/` - Zarr format (optimized for random access)
-
- ```
- sourcedata/braindecode/
- ├── dataset_description.json
- ├── participants.tsv
- ├── dataset.zarr/
- └── sub-<label>/
- └── eeg/
- ├── *_events.tsv
- ├── *_channels.tsv
- └── *_eeg.json
- ```
-
- ### Accessing Metadata
-
- ```python
- # Participants info
- if hasattr(dataset, "participants"):
- print(dataset.participants)
-
- # Events for a recording
- if hasattr(dataset.datasets[0], "bids_events"):
- print(dataset.datasets[0].bids_events)
-
- # Channel info
- if hasattr(dataset.datasets[0], "bids_channels"):
- print(dataset.datasets[0].bids_channels)
- ```
+ | Number of windows / samples | {n_windows} |
+ | Total size | {total_size_mb:.2f} MB |
+ | Storage format | {format} |

- ---
+ ## Usage
+
+ To load this dataset::
+
+ .. code-block:: python
+
+ from braindecode.datasets import BaseConcatDataset
+
+ # Load dataset from Hugging Face Hub
+ dataset = BaseConcatDataset.pull_from_hub("username/dataset-name")
+
+ # Access data
+ X, y, metainfo = dataset[0]
+ # X: EEG data (n_channels, n_times)
+ # y: label/target
+ # metainfo: window indices
+
+ ## Using with PyTorch DataLoader
+
+ ::
+
+ from torch.utils.data import DataLoader
+
+ # Create DataLoader for training
+ train_loader = DataLoader(
+ dataset,
+ batch_size=32,
+ shuffle=True,
+ num_workers=4
+ )
+
+ # Training loop
+ for X, y, metainfo in train_loader:
+ # X shape: [batch_size, n_channels, n_times]
+ # y shape: [batch_size]
+ # metainfo shape: [batch_size, 2] (start and end indices)
+ # Process your batch...
+
+ ## Dataset Format
+
+ This dataset is stored in **Zarr** format, optimized for:
+ - Fast random access during training (critical for PyTorch DataLoader)
+ - Efficient compression with blosc
+ - Cloud-native storage compatibility

- *Created with [braindecode](https://braindecode.org)*
+ For more information about braindecode, visit: https://braindecode.org
  """