braindecode 1.3.0.dev177628147__py3-none-any.whl → 1.3.0.dev182330353__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- braindecode/augmentation/functional.py +0 -101
- braindecode/augmentation/transforms.py +0 -74
- braindecode/datasets/base.py +3 -18
- braindecode/datautil/serialization.py +1 -0
- braindecode/models/__init__.py +1 -8
- braindecode/models/summary.csv +0 -1
- braindecode/models/util.py +0 -84
- braindecode/preprocessing/__init__.py +0 -5
- braindecode/preprocessing/eegprep_preprocess.py +19 -134
- braindecode/preprocessing/mne_preprocess.py +25 -56
- braindecode/preprocessing/preprocess.py +41 -126
- braindecode/preprocessing/util.py +0 -11
- braindecode/version.py +1 -1
- {braindecode-1.3.0.dev177628147.dist-info → braindecode-1.3.0.dev182330353.dist-info}/METADATA +5 -11
- {braindecode-1.3.0.dev177628147.dist-info → braindecode-1.3.0.dev182330353.dist-info}/RECORD +19 -24
- braindecode/datasets/hub.py +0 -962
- braindecode/datasets/hub_validation.py +0 -113
- braindecode/datasets/registry.py +0 -120
- braindecode/datautil/hub_formats.py +0 -180
- braindecode/models/luna.py +0 -836
- {braindecode-1.3.0.dev177628147.dist-info → braindecode-1.3.0.dev182330353.dist-info}/WHEEL +0 -0
- {braindecode-1.3.0.dev177628147.dist-info → braindecode-1.3.0.dev182330353.dist-info}/licenses/LICENSE.txt +0 -0
- {braindecode-1.3.0.dev177628147.dist-info → braindecode-1.3.0.dev182330353.dist-info}/licenses/NOTICE.txt +0 -0
- {braindecode-1.3.0.dev177628147.dist-info → braindecode-1.3.0.dev182330353.dist-info}/top_level.txt +0 -0
braindecode/datasets/hub_validation.py
DELETED

@@ -1,113 +0,0 @@
-"""
-Shared validation utilities for Hub format operations.
-
-This module provides validation functions used by hub.py to avoid code duplication.
-"""
-
-# Authors: Kuntal Kokate
-#
-# License: BSD (3-clause)
-
-from typing import Any, List, Tuple
-
-from .registry import get_dataset_type
-
-
-def validate_dataset_uniformity(
-    datasets: List[Any],
-) -> Tuple[str, List[str], float]:
-    """
-    Validate all datasets have uniform type, channels, and sampling frequency.
-
-    Parameters
-    ----------
-    datasets : list
-        List of dataset objects to validate.
-
-    Returns
-    -------
-    dataset_type : str
-        The validated dataset type (WindowsDataset, EEGWindowsDataset, or RawDataset).
-    first_ch_names : list of str
-        Channel names from the first dataset.
-    first_sfreq : float
-        Sampling frequency from the first dataset.
-
-    Raises
-    ------
-    ValueError
-        If datasets have mixed types, inconsistent channels, or inconsistent
-        sampling frequencies.
-    TypeError
-        If dataset type is not supported.
-    """
-    if not datasets:
-        raise ValueError("No datasets provided for validation.")
-
-    first_ds = datasets[0]
-    dataset_type = get_dataset_type(first_ds)
-
-    # Get reference channel names and sampling frequency from the first dataset
-    first_ch_names, first_sfreq = _get_ch_names_and_sfreq(first_ds, dataset_type)
-
-    # Validate all datasets have uniform properties
-    for i, ds in enumerate(datasets):
-        ds_type = get_dataset_type(ds)
-        if ds_type != dataset_type:
-            raise ValueError(
-                f"Mixed dataset types in concat: dataset 0 is {dataset_type} "
-                f"but dataset {i} is {ds_type}"
-            )
-
-        ch_names, sfreq = _get_ch_names_and_sfreq(ds, dataset_type)
-
-        if ch_names != first_ch_names:
-            raise ValueError(
-                f"Inconsistent channel names: dataset 0 has {first_ch_names} "
-                f"but dataset {i} has {ch_names}"
-            )
-
-        if sfreq != first_sfreq:
-            _raise_sfreq_error(first_sfreq, sfreq, i)
-
-    return dataset_type, first_ch_names, first_sfreq
-
-
-def _get_ch_names_and_sfreq(ds: Any, dataset_type: str) -> Tuple[List[str], float]:
-    """Return (ch_names, sfreq) for supported dataset types."""
-    if dataset_type == "WindowsDataset":
-        obj = ds.windows
-    elif dataset_type in ("EEGWindowsDataset", "RawDataset"):
-        obj = ds.raw
-    else:
-        raise TypeError(f"Unsupported dataset type: {dataset_type}")
-
-    return obj.ch_names, obj.info["sfreq"]
-
-
-def _raise_sfreq_error(expected: float, actual: float, idx: int):
-    """
-    Raise standardized sampling frequency error.
-
-    Parameters
-    ----------
-    expected : float
-        Expected sampling frequency from dataset 0.
-    actual : float
-        Actual sampling frequency from current dataset.
-    idx : int
-        Index of the dataset with inconsistent sampling frequency.
-
-    Raises
-    ------
-    ValueError
-        Always raised with standardized error message.
-    """
-    raise ValueError(
-        f"Inconsistent sampling frequencies: dataset 0 has {expected} Hz "
-        f"but dataset {idx} has {actual} Hz. "
-        f"Please resample all datasets to a common frequency before saving. "
-        f"Use braindecode.preprocessing.preprocess("
-        f"[Preprocessor(Resample(sfreq={expected}))], concat_ds) "
-        f"to resample your datasets."
-    )
braindecode/datasets/registry.py
DELETED
@@ -1,120 +0,0 @@
-"""
-Dataset registry for Hub integration.
-
-Datasets register themselves here so Hub code can look them up by name
-without direct imports (avoiding circular dependencies).
-"""
-
-# Authors: Kuntal Kokate
-#
-# License: BSD (3-clause)
-
-from typing import Any, Dict, Type
-
-# Global registry mapping dataset class names to classes
-_DATASET_REGISTRY: Dict[str, Type] = {}
-
-
-def register_dataset(cls: Type) -> Type:
-    """
-    Decorator to register a dataset class in the global registry.
-
-    Parameters
-    ----------
-    cls : Type
-        The dataset class to register.
-
-    Returns
-    -------
-    Type
-        The same class (unchanged), so this can be used as a decorator.
-    """
-    _DATASET_REGISTRY[cls.__name__] = cls
-    return cls
-
-
-def _available_datasets_str() -> str:
-    """Return a human-readable list of registered dataset class names."""
-    if not _DATASET_REGISTRY:
-        return "<no registered datasets>"
-    return ", ".join(_DATASET_REGISTRY.keys())
-
-
-def get_dataset_class(name: str) -> Type:
-    """
-    Retrieve a registered dataset class by name.
-
-    Parameters
-    ----------
-    name : str
-        Name of the dataset class (e.g., 'WindowsDataset').
-
-    Returns
-    -------
-    Type
-        The dataset class.
-
-    Raises
-    ------
-    KeyError
-        If the class name is not registered.
-    """
-    try:
-        return _DATASET_REGISTRY[name]
-    except KeyError as exc:
-        raise KeyError(
-            f"Dataset class '{name}' not found in registry. "
-            f"Available classes: {_available_datasets_str()}"
-        ) from exc
-
-
-def get_dataset_type(obj: Any) -> str:
-    """
-    Get the registered type name for a dataset instance.
-
-    Parameters
-    ----------
-    obj : Any
-        The object to check.
-
-    Returns
-    -------
-    str
-        The name of the dataset class (e.g., 'WindowsDataset').
-
-    Raises
-    ------
-    TypeError
-        If the object is not an instance of any registered dataset class.
-    """
-    for cls in _DATASET_REGISTRY.values():
-        if isinstance(obj, cls):
-            return cls.__name__
-
-    raise TypeError(
-        f"Object of type {type(obj).__name__} is not a registered dataset class. "
-        f"Available classes: {_available_datasets_str()}"
-    )
-
-
-def is_registered_dataset(obj: Any, class_name: str) -> bool:
-    """
-    Check if an object is an instance of a registered dataset class.
-
-    Parameters
-    ----------
-    obj : Any
-        The object to check.
-    class_name : str
-        Name of the dataset class to check against.
-
-    Returns
-    -------
-    bool
-        True if obj is an instance of the named class, False otherwise.
-    """
-    try:
-        cls = get_dataset_class(class_name)
-    except KeyError:
-        return False
-    return isinstance(obj, cls)
braindecode/datautil/hub_formats.py
DELETED

@@ -1,180 +0,0 @@
-"""
-Format converters for Hugging Face Hub integration.
-
-This module provides Zarr format converters to transform EEG datasets for
-efficient storage and fast random access during training on the Hugging Face Hub.
-
-This module provides a standalone functional API that delegates to the
-HubDatasetMixin methods for all actual implementations.
-"""
-
-# Authors: Kuntal Kokate
-#
-# License: BSD (3-clause)
-
-from __future__ import annotations
-
-import shutil
-from pathlib import Path
-from typing import TYPE_CHECKING, Dict, List, Optional, Union
-
-# Import registry for dynamic class lookup
-from ..datasets.registry import get_dataset_class
-
-# Import dataset classes for type checking only
-if TYPE_CHECKING:
-    from ..datasets.base import BaseConcatDataset
-
-
-# =============================================================================
-# Zarr Format Converters
-# =============================================================================
-
-
-def convert_to_zarr(
-    dataset: BaseConcatDataset,
-    output_path: Union[str, Path],
-    compression: str = "blosc",
-    compression_level: int = 5,
-    overwrite: bool = False,
-) -> Path:
-    """Convert BaseConcatDataset to Zarr format.
-
-    Zarr provides cloud-native chunked storage, optimized for random access
-    during training. This is the format used for Hugging Face Hub uploads,
-    based on comprehensive benchmarking showing:
-    - Fastest random access: 0.010 ms (critical for PyTorch DataLoader)
-    - Fast save/load: 0.46s / 0.12s
-    - Good compression: ~23% size reduction with blosc
-
-    Parameters
-    ----------
-    dataset : BaseConcatDataset
-        The dataset to convert.
-    output_path : str | Path
-        Path where the Zarr directory will be created.
-    compression : str, default="blosc"
-        Compression algorithm. Options: "blosc" (recommended), "zstd", "gzip", None.
-        blosc uses zstd codec by default, providing best balance of speed and compression.
-    compression_level : int, default=5
-        Compression level (0-9). Level 5 provides optimal balance based on benchmarks.
-    overwrite : bool, default=False
-        Whether to overwrite existing directory.
-
-    Returns
-    -------
-    Path
-        Path to the created Zarr directory.
-
-    Notes
-    -----
-    The chunking strategy is optimized for random access:
-    - Windowed data: Each window is a separate chunk (1, n_channels, n_times)
-    - Raw data: Chunks of (n_channels, 10000) samples
-
-    Examples
-    --------
-    >>> dataset = NMT(path=path, preload=True)
-    >>> # Use default settings (optimal from benchmarks)
-    >>> zarr_path = convert_to_zarr(dataset, "dataset.zarr")
-    >>>
-    >>> # Or customize compression
-    >>> zarr_path = convert_to_zarr(
-    ...     dataset, "dataset.zarr",
-    ...     compression="blosc",
-    ...     compression_level=5
-    ... )
-    """
-    output_path = Path(output_path)
-
-    if output_path.exists():
-        if not overwrite:
-            raise FileExistsError(
-                f"{output_path} already exists. Set overwrite=True to replace it."
-            )
-        # Remove existing directory if overwrite is True
-        shutil.rmtree(output_path)
-
-    # Delegate to HubDatasetMixin method
-    dataset._convert_to_zarr_inline(output_path, compression, compression_level)
-
-    return output_path
-
-
-def load_from_zarr(
-    input_path: Union[str, Path],
-    preload: bool = True,
-    ids_to_load: Optional[List[int]] = None,
-):
-    """Load BaseConcatDataset from Zarr format.
-
-    Zarr is the format used for braindecode Hub datasets, providing
-    the fastest random access performance for training with PyTorch.
-
-    Parameters
-    ----------
-    input_path : str | Path
-        Path to the Zarr directory.
-    preload : bool, default=True
-        Whether to load data into memory. If False, uses lazy loading
-        (data is loaded on-demand during training).
-    ids_to_load : list of int | None
-        Specific recording IDs to load. If None, loads all.
-
-    Returns
-    -------
-    BaseConcatDataset
-        The loaded dataset.
-
-    Examples
-    --------
-    >>> # Load from local zarr directory
-    >>> dataset = load_from_zarr("dataset.zarr", preload=True)
-    >>>
-    >>> # Load from Hugging Face Hub (handled automatically)
-    >>> from braindecode.datasets import BaseConcatDataset
-    >>> dataset = BaseConcatDataset.from_pretrained("username/dataset-name")
-    """
-    # Delegate to HubDatasetMixin static method
-    BaseConcatDataset = get_dataset_class("BaseConcatDataset")
-
-    # Load full dataset using mixin method
-    dataset = BaseConcatDataset._load_from_zarr_inline(Path(input_path), preload)
-
-    # Filter to specific IDs if requested
-    if ids_to_load is not None:
-        # Get only the requested datasets
-        filtered_datasets = [dataset.datasets[i] for i in ids_to_load]
-        dataset = BaseConcatDataset(filtered_datasets)
-
-    return dataset
-
-
-# =============================================================================
-# Utility Functions
-# =============================================================================
-
-
-def get_format_info(dataset: "BaseConcatDataset") -> Dict:
-    """Get dataset information for Hub metadata.
-
-    Validates that all datasets in the concat have uniform properties
-    (channels, sampling frequency) and raises an error if not.
-
-    Parameters
-    ----------
-    dataset : BaseConcatDataset
-        The dataset to analyze.
-
-    Returns
-    -------
-    dict
-        Dictionary with dataset statistics and format info.
-
-    Raises
-    ------
-    ValueError
-        If datasets have inconsistent channels or sampling frequencies.
-    """
-    # Delegate to HubDatasetMixin method
-    return dataset._get_format_info_inline()