braindecode 1.3.0.dev177628147__py3-none-any.whl → 1.3.0.dev182330353__py3-none-any.whl

This diff shows the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in the public registry.
@@ -1,113 +0,0 @@
- """
- Shared validation utilities for Hub format operations.
-
- This module provides validation functions used by hub.py to avoid code duplication.
- """
-
- # Authors: Kuntal Kokate
- #
- # License: BSD (3-clause)
-
- from typing import Any, List, Tuple
-
- from .registry import get_dataset_type
-
-
- def validate_dataset_uniformity(
-     datasets: List[Any],
- ) -> Tuple[str, List[str], float]:
-     """
-     Validate all datasets have uniform type, channels, and sampling frequency.
-
-     Parameters
-     ----------
-     datasets : list
-         List of dataset objects to validate.
-
-     Returns
-     -------
-     dataset_type : str
-         The validated dataset type (WindowsDataset, EEGWindowsDataset, or RawDataset).
-     first_ch_names : list of str
-         Channel names from the first dataset.
-     first_sfreq : float
-         Sampling frequency from the first dataset.
-
-     Raises
-     ------
-     ValueError
-         If datasets have mixed types, inconsistent channels, or inconsistent
-         sampling frequencies.
-     TypeError
-         If dataset type is not supported.
-     """
-     if not datasets:
-         raise ValueError("No datasets provided for validation.")
-
-     first_ds = datasets[0]
-     dataset_type = get_dataset_type(first_ds)
-
-     # Get reference channel names and sampling frequency from the first dataset
-     first_ch_names, first_sfreq = _get_ch_names_and_sfreq(first_ds, dataset_type)
-
-     # Validate all datasets have uniform properties
-     for i, ds in enumerate(datasets):
-         ds_type = get_dataset_type(ds)
-         if ds_type != dataset_type:
-             raise ValueError(
-                 f"Mixed dataset types in concat: dataset 0 is {dataset_type} "
-                 f"but dataset {i} is {ds_type}"
-             )
-
-         ch_names, sfreq = _get_ch_names_and_sfreq(ds, dataset_type)
-
-         if ch_names != first_ch_names:
-             raise ValueError(
-                 f"Inconsistent channel names: dataset 0 has {first_ch_names} "
-                 f"but dataset {i} has {ch_names}"
-             )
-
-         if sfreq != first_sfreq:
-             _raise_sfreq_error(first_sfreq, sfreq, i)
-
-     return dataset_type, first_ch_names, first_sfreq
-
-
- def _get_ch_names_and_sfreq(ds: Any, dataset_type: str) -> Tuple[List[str], float]:
-     """Return (ch_names, sfreq) for supported dataset types."""
-     if dataset_type == "WindowsDataset":
-         obj = ds.windows
-     elif dataset_type in ("EEGWindowsDataset", "RawDataset"):
-         obj = ds.raw
-     else:
-         raise TypeError(f"Unsupported dataset type: {dataset_type}")
-
-     return obj.ch_names, obj.info["sfreq"]
-
-
- def _raise_sfreq_error(expected: float, actual: float, idx: int):
-     """
-     Raise standardized sampling frequency error.
-
-     Parameters
-     ----------
-     expected : float
-         Expected sampling frequency from dataset 0.
-     actual : float
-         Actual sampling frequency from current dataset.
-     idx : int
-         Index of the dataset with inconsistent sampling frequency.
-
-     Raises
-     ------
-     ValueError
-         Always raised with standardized error message.
-     """
-     raise ValueError(
-         f"Inconsistent sampling frequencies: dataset 0 has {expected} Hz "
-         f"but dataset {idx} has {actual} Hz. "
-         f"Please resample all datasets to a common frequency before saving. "
-         f"Use braindecode.preprocessing.preprocess("
-         f"[Preprocessor(Resample(sfreq={expected}))], concat_ds) "
-         f"to resample your datasets."
-     )
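
The invariant this removed module enforces, identical channel names and sampling frequency across all recordings before a Hub export, can be sketched independently of braindecode. The example below is a minimal, self-contained illustration of that check; the stand-in (ch_names, sfreq) tuples are not the braindecode API.

# Minimal sketch of the uniformity check performed by validate_dataset_uniformity:
# every recording must share the channel names and sampling frequency of
# recording 0. Plain tuples stand in for braindecode dataset objects.
recordings = [
    (["C3", "C4", "Cz"], 250.0),  # dataset 0 (reference)
    (["C3", "C4", "Cz"], 250.0),  # dataset 1: consistent
    (["C3", "C4", "Cz"], 128.0),  # dataset 2: inconsistent sampling frequency
]

first_ch_names, first_sfreq = recordings[0]
try:
    for i, (ch_names, sfreq) in enumerate(recordings):
        if ch_names != first_ch_names:
            raise ValueError(f"Inconsistent channel names: dataset 0 vs dataset {i}")
        if sfreq != first_sfreq:
            raise ValueError(
                f"Inconsistent sampling frequencies: dataset 0 has {first_sfreq} Hz "
                f"but dataset {i} has {sfreq} Hz"
            )
except ValueError as err:
    print(err)  # dataset 2 triggers the sampling-frequency error
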
@@ -1,120 +0,0 @@
- """
- Dataset registry for Hub integration.
-
- Datasets register themselves here so Hub code can look them up by name
- without direct imports (avoiding circular dependencies).
- """
-
- # Authors: Kuntal Kokate
- #
- # License: BSD (3-clause)
-
- from typing import Any, Dict, Type
-
- # Global registry mapping dataset class names to classes
- _DATASET_REGISTRY: Dict[str, Type] = {}
-
-
- def register_dataset(cls: Type) -> Type:
-     """
-     Decorator to register a dataset class in the global registry.
-
-     Parameters
-     ----------
-     cls : Type
-         The dataset class to register.
-
-     Returns
-     -------
-     Type
-         The same class (unchanged), so this can be used as a decorator.
-     """
-     _DATASET_REGISTRY[cls.__name__] = cls
-     return cls
-
-
- def _available_datasets_str() -> str:
-     """Return a human-readable list of registered dataset class names."""
-     if not _DATASET_REGISTRY:
-         return "<no registered datasets>"
-     return ", ".join(_DATASET_REGISTRY.keys())
-
-
- def get_dataset_class(name: str) -> Type:
-     """
-     Retrieve a registered dataset class by name.
-
-     Parameters
-     ----------
-     name : str
-         Name of the dataset class (e.g., 'WindowsDataset').
-
-     Returns
-     -------
-     Type
-         The dataset class.
-
-     Raises
-     ------
-     KeyError
-         If the class name is not registered.
-     """
-     try:
-         return _DATASET_REGISTRY[name]
-     except KeyError as exc:
-         raise KeyError(
-             f"Dataset class '{name}' not found in registry. "
-             f"Available classes: {_available_datasets_str()}"
-         ) from exc
-
-
- def get_dataset_type(obj: Any) -> str:
-     """
-     Get the registered type name for a dataset instance.
-
-     Parameters
-     ----------
-     obj : Any
-         The object to check.
-
-     Returns
-     -------
-     str
-         The name of the dataset class (e.g., 'WindowsDataset').
-
-     Raises
-     ------
-     TypeError
-         If the object is not an instance of any registered dataset class.
-     """
-     for cls in _DATASET_REGISTRY.values():
-         if isinstance(obj, cls):
-             return cls.__name__
-
-     raise TypeError(
-         f"Object of type {type(obj).__name__} is not a registered dataset class. "
-         f"Available classes: {_available_datasets_str()}"
-     )
-
-
- def is_registered_dataset(obj: Any, class_name: str) -> bool:
-     """
-     Check if an object is an instance of a registered dataset class.
-
-     Parameters
-     ----------
-     obj : Any
-         The object to check.
-     class_name : str
-         Name of the dataset class to check against.
-
-     Returns
-     -------
-     bool
-         True if obj is an instance of the named class, False otherwise.
-     """
-     try:
-         cls = get_dataset_class(class_name)
-     except KeyError:
-         return False
-     return isinstance(obj, cls)
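
The registry above is a plain decorator-based lookup table, and the pattern can be exercised in isolation. The self-contained sketch below mirrors register_dataset and get_dataset_type with a dummy class; the dummy name is illustrative only and nothing here touches braindecode itself.

# Self-contained sketch of the decorator-based registry pattern used by the
# removed module (dummy class, not a real braindecode dataset).
from typing import Any, Dict, Type

_REGISTRY: Dict[str, Type] = {}


def register_dataset(cls: Type) -> Type:
    # Store the class under its own name and hand it back unchanged,
    # so the function works as a class decorator.
    _REGISTRY[cls.__name__] = cls
    return cls


def get_dataset_type(obj: Any) -> str:
    # Map an instance back to the name of its registered class.
    for cls in _REGISTRY.values():
        if isinstance(obj, cls):
            return cls.__name__
    raise TypeError(f"{type(obj).__name__} is not a registered dataset class")


@register_dataset
class DummyWindowsDataset:
    pass


print(get_dataset_type(DummyWindowsDataset()))  # -> "DummyWindowsDataset"

Because lookups go through the class name as a string, Hub code can resolve dataset classes lazily without importing the dataset modules directly, which is the circular-import problem the module docstring mentions.
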
@@ -1,180 +0,0 @@
- """
- Format converters for Hugging Face Hub integration.
-
- This module provides Zarr format converters to transform EEG datasets for
- efficient storage and fast random access during training on the Hugging Face Hub.
-
- This module provides a standalone functional API that delegates to the
- HubDatasetMixin methods for all actual implementations.
- """
-
- # Authors: Kuntal Kokate
- #
- # License: BSD (3-clause)
-
- from __future__ import annotations
-
- import shutil
- from pathlib import Path
- from typing import TYPE_CHECKING, Dict, List, Optional, Union
-
- # Import registry for dynamic class lookup
- from ..datasets.registry import get_dataset_class
-
- # Import dataset classes for type checking only
- if TYPE_CHECKING:
-     from ..datasets.base import BaseConcatDataset
-
-
- # =============================================================================
- # Zarr Format Converters
- # =============================================================================
-
-
- def convert_to_zarr(
-     dataset: BaseConcatDataset,
-     output_path: Union[str, Path],
-     compression: str = "blosc",
-     compression_level: int = 5,
-     overwrite: bool = False,
- ) -> Path:
-     """Convert BaseConcatDataset to Zarr format.
-
-     Zarr provides cloud-native chunked storage, optimized for random access
-     during training. This is the format used for Hugging Face Hub uploads,
-     based on comprehensive benchmarking showing:
-     - Fastest random access: 0.010 ms (critical for PyTorch DataLoader)
-     - Fast save/load: 0.46s / 0.12s
-     - Good compression: ~23% size reduction with blosc
-
-     Parameters
-     ----------
-     dataset : BaseConcatDataset
-         The dataset to convert.
-     output_path : str | Path
-         Path where the Zarr directory will be created.
-     compression : str, default="blosc"
-         Compression algorithm. Options: "blosc" (recommended), "zstd", "gzip", None.
-         blosc uses zstd codec by default, providing best balance of speed and compression.
-     compression_level : int, default=5
-         Compression level (0-9). Level 5 provides optimal balance based on benchmarks.
-     overwrite : bool, default=False
-         Whether to overwrite existing directory.
-
-     Returns
-     -------
-     Path
-         Path to the created Zarr directory.
-
-     Notes
-     -----
-     The chunking strategy is optimized for random access:
-     - Windowed data: Each window is a separate chunk (1, n_channels, n_times)
-     - Raw data: Chunks of (n_channels, 10000) samples
-
-     Examples
-     --------
-     >>> dataset = NMT(path=path, preload=True)
-     >>> # Use default settings (optimal from benchmarks)
-     >>> zarr_path = convert_to_zarr(dataset, "dataset.zarr")
-     >>>
-     >>> # Or customize compression
-     >>> zarr_path = convert_to_zarr(
-     ...     dataset, "dataset.zarr",
-     ...     compression="blosc",
-     ...     compression_level=5
-     ... )
-     """
-     output_path = Path(output_path)
-
-     if output_path.exists():
-         if not overwrite:
-             raise FileExistsError(
-                 f"{output_path} already exists. Set overwrite=True to replace it."
-             )
-         # Remove existing directory if overwrite is True
-         shutil.rmtree(output_path)
-
-     # Delegate to HubDatasetMixin method
-     dataset._convert_to_zarr_inline(output_path, compression, compression_level)
-
-     return output_path
-
-
- def load_from_zarr(
-     input_path: Union[str, Path],
-     preload: bool = True,
-     ids_to_load: Optional[List[int]] = None,
- ):
-     """Load BaseConcatDataset from Zarr format.
-
-     Zarr is the format used for braindecode Hub datasets, providing
-     the fastest random access performance for training with PyTorch.
-
-     Parameters
-     ----------
-     input_path : str | Path
-         Path to the Zarr directory.
-     preload : bool, default=True
-         Whether to load data into memory. If False, uses lazy loading
-         (data is loaded on-demand during training).
-     ids_to_load : list of int | None
-         Specific recording IDs to load. If None, loads all.
-
-     Returns
-     -------
-     BaseConcatDataset
-         The loaded dataset.
-
-     Examples
-     --------
-     >>> # Load from local zarr directory
-     >>> dataset = load_from_zarr("dataset.zarr", preload=True)
-     >>>
-     >>> # Load from Hugging Face Hub (handled automatically)
-     >>> from braindecode.datasets import BaseConcatDataset
-     >>> dataset = BaseConcatDataset.from_pretrained("username/dataset-name")
-     """
-     # Delegate to HubDatasetMixin static method
-     BaseConcatDataset = get_dataset_class("BaseConcatDataset")
-
-     # Load full dataset using mixin method
-     dataset = BaseConcatDataset._load_from_zarr_inline(Path(input_path), preload)
-
-     # Filter to specific IDs if requested
-     if ids_to_load is not None:
-         # Get only the requested datasets
-         filtered_datasets = [dataset.datasets[i] for i in ids_to_load]
-         dataset = BaseConcatDataset(filtered_datasets)
-
-     return dataset
-
-
- # =============================================================================
- # Utility Functions
- # =============================================================================
-
-
- def get_format_info(dataset: "BaseConcatDataset") -> Dict:
-     """Get dataset information for Hub metadata.
-
-     Validates that all datasets in the concat have uniform properties
-     (channels, sampling frequency) and raises an error if not.
-
-     Parameters
-     ----------
-     dataset : BaseConcatDataset
-         The dataset to analyze.
-
-     Returns
-     -------
-     dict
-         Dictionary with dataset statistics and format info.
-
-     Raises
-     ------
-     ValueError
-         If datasets have inconsistent channels or sampling frequencies.
-     """
-     # Delegate to HubDatasetMixin method
-     return dataset._get_format_info_inline()
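
Read together, the removed functional API amounts to a save/load round trip around the HubDatasetMixin methods. A hedged sketch of that round trip is given below: `dataset` is assumed to be an existing braindecode BaseConcatDataset, the paths are placeholders, and the functions' import path is not visible in this diff, so no import line is shown.

# Round-trip sketch built only from the signatures shown above; the module
# these functions lived in is not named in the diff, so they are assumed to
# be importable from wherever the package exposed them.
zarr_path = convert_to_zarr(
    dataset,                  # existing BaseConcatDataset (assumed)
    "dataset.zarr",
    compression="blosc",      # blosc (zstd codec), level 5: the documented defaults
    compression_level=5,
    overwrite=True,           # replace any previous export at this path
)

full = load_from_zarr(zarr_path, preload=False)          # lazy loading for training
subset = load_from_zarr(zarr_path, ids_to_load=[0, 2])   # keep only recordings 0 and 2
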