braindecode 1.3.0.dev177069446__py3-none-any.whl
This diff represents the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the packages as they appear in their respective public registries.
- braindecode/__init__.py +9 -0
- braindecode/augmentation/__init__.py +52 -0
- braindecode/augmentation/base.py +225 -0
- braindecode/augmentation/functional.py +1300 -0
- braindecode/augmentation/transforms.py +1356 -0
- braindecode/classifier.py +258 -0
- braindecode/datasets/__init__.py +44 -0
- braindecode/datasets/base.py +823 -0
- braindecode/datasets/bbci.py +693 -0
- braindecode/datasets/bcicomp.py +193 -0
- braindecode/datasets/bids/__init__.py +54 -0
- braindecode/datasets/bids/datasets.py +239 -0
- braindecode/datasets/bids/format.py +717 -0
- braindecode/datasets/bids/hub.py +987 -0
- braindecode/datasets/bids/hub_format.py +717 -0
- braindecode/datasets/bids/hub_io.py +197 -0
- braindecode/datasets/bids/hub_validation.py +114 -0
- braindecode/datasets/bids/iterable.py +220 -0
- braindecode/datasets/chb_mit.py +163 -0
- braindecode/datasets/mne.py +170 -0
- braindecode/datasets/moabb.py +219 -0
- braindecode/datasets/nmt.py +313 -0
- braindecode/datasets/registry.py +120 -0
- braindecode/datasets/siena.py +162 -0
- braindecode/datasets/sleep_physio_challe_18.py +411 -0
- braindecode/datasets/sleep_physionet.py +125 -0
- braindecode/datasets/tuh.py +591 -0
- braindecode/datasets/utils.py +67 -0
- braindecode/datasets/xy.py +96 -0
- braindecode/datautil/__init__.py +62 -0
- braindecode/datautil/channel_utils.py +114 -0
- braindecode/datautil/hub_formats.py +180 -0
- braindecode/datautil/serialization.py +359 -0
- braindecode/datautil/util.py +154 -0
- braindecode/eegneuralnet.py +372 -0
- braindecode/functional/__init__.py +22 -0
- braindecode/functional/functions.py +251 -0
- braindecode/functional/initialization.py +47 -0
- braindecode/models/__init__.py +117 -0
- braindecode/models/atcnet.py +830 -0
- braindecode/models/attentionbasenet.py +727 -0
- braindecode/models/attn_sleep.py +549 -0
- braindecode/models/base.py +574 -0
- braindecode/models/bendr.py +493 -0
- braindecode/models/biot.py +537 -0
- braindecode/models/brainmodule.py +845 -0
- braindecode/models/config.py +233 -0
- braindecode/models/contrawr.py +319 -0
- braindecode/models/ctnet.py +541 -0
- braindecode/models/deep4.py +376 -0
- braindecode/models/deepsleepnet.py +417 -0
- braindecode/models/eegconformer.py +475 -0
- braindecode/models/eeginception_erp.py +379 -0
- braindecode/models/eeginception_mi.py +379 -0
- braindecode/models/eegitnet.py +302 -0
- braindecode/models/eegminer.py +256 -0
- braindecode/models/eegnet.py +359 -0
- braindecode/models/eegnex.py +354 -0
- braindecode/models/eegsimpleconv.py +201 -0
- braindecode/models/eegsym.py +917 -0
- braindecode/models/eegtcnet.py +337 -0
- braindecode/models/fbcnet.py +225 -0
- braindecode/models/fblightconvnet.py +315 -0
- braindecode/models/fbmsnet.py +338 -0
- braindecode/models/hybrid.py +126 -0
- braindecode/models/ifnet.py +443 -0
- braindecode/models/labram.py +1316 -0
- braindecode/models/luna.py +891 -0
- braindecode/models/medformer.py +760 -0
- braindecode/models/msvtnet.py +377 -0
- braindecode/models/patchedtransformer.py +640 -0
- braindecode/models/reve.py +843 -0
- braindecode/models/sccnet.py +280 -0
- braindecode/models/shallow_fbcsp.py +212 -0
- braindecode/models/signal_jepa.py +1122 -0
- braindecode/models/sinc_shallow.py +339 -0
- braindecode/models/sleep_stager_blanco_2020.py +169 -0
- braindecode/models/sleep_stager_chambon_2018.py +159 -0
- braindecode/models/sparcnet.py +426 -0
- braindecode/models/sstdpn.py +869 -0
- braindecode/models/summary.csv +47 -0
- braindecode/models/syncnet.py +234 -0
- braindecode/models/tcn.py +275 -0
- braindecode/models/tidnet.py +397 -0
- braindecode/models/tsinception.py +295 -0
- braindecode/models/usleep.py +439 -0
- braindecode/models/util.py +369 -0
- braindecode/modules/__init__.py +92 -0
- braindecode/modules/activation.py +86 -0
- braindecode/modules/attention.py +883 -0
- braindecode/modules/blocks.py +160 -0
- braindecode/modules/convolution.py +330 -0
- braindecode/modules/filter.py +654 -0
- braindecode/modules/layers.py +216 -0
- braindecode/modules/linear.py +70 -0
- braindecode/modules/parametrization.py +38 -0
- braindecode/modules/stats.py +87 -0
- braindecode/modules/util.py +85 -0
- braindecode/modules/wrapper.py +90 -0
- braindecode/preprocessing/__init__.py +271 -0
- braindecode/preprocessing/eegprep_preprocess.py +1317 -0
- braindecode/preprocessing/mne_preprocess.py +240 -0
- braindecode/preprocessing/preprocess.py +579 -0
- braindecode/preprocessing/util.py +177 -0
- braindecode/preprocessing/windowers.py +1037 -0
- braindecode/regressor.py +234 -0
- braindecode/samplers/__init__.py +18 -0
- braindecode/samplers/base.py +399 -0
- braindecode/samplers/ssl.py +263 -0
- braindecode/training/__init__.py +23 -0
- braindecode/training/callbacks.py +23 -0
- braindecode/training/losses.py +105 -0
- braindecode/training/scoring.py +477 -0
- braindecode/util.py +419 -0
- braindecode/version.py +1 -0
- braindecode/visualization/__init__.py +8 -0
- braindecode/visualization/confusion_matrices.py +289 -0
- braindecode/visualization/gradients.py +62 -0
- braindecode-1.3.0.dev177069446.dist-info/METADATA +230 -0
- braindecode-1.3.0.dev177069446.dist-info/RECORD +124 -0
- braindecode-1.3.0.dev177069446.dist-info/WHEEL +5 -0
- braindecode-1.3.0.dev177069446.dist-info/licenses/LICENSE.txt +31 -0
- braindecode-1.3.0.dev177069446.dist-info/licenses/NOTICE.txt +20 -0
- braindecode-1.3.0.dev177069446.dist-info/top_level.txt +1 -0
braindecode/datasets/xy.py
@@ -0,0 +1,96 @@
+# Authors: Lukas Gemein <l.gemein@gmail.com>
+#          Robin Schirrmeister <robintibor@gmail.com>
+#
+# License: BSD (3-clause)
+
+from __future__ import annotations
+
+import logging
+
+import mne
+import numpy as np
+import pandas as pd
+from numpy.typing import ArrayLike, NDArray
+
+from .base import BaseConcatDataset, RawDataset
+
+log = logging.getLogger(__name__)
+
+
+def create_from_X_y(
+    X: NDArray,
+    y: ArrayLike,
+    drop_last_window: bool,
+    sfreq: float,
+    ch_names: ArrayLike = None,
+    window_size_samples: int | None = None,
+    window_stride_samples: int | None = None,
+) -> BaseConcatDataset:
+    """Create a BaseConcatDataset of WindowsDatasets from X and y, to be used
+    for decoding with skorch and braindecode, where X is a list of pre-cut
+    trials and y are the corresponding targets.
+
+    Parameters
+    ----------
+    X : array-like
+        list of pre-cut trials as n_trials x n_channels x n_times
+    y : array-like
+        targets corresponding to the trials
+    drop_last_window : bool
+        whether to drop the last, overlapping window when the windows do not
+        equally divide the continuous signal
+    sfreq : float
+        Sampling frequency of signals.
+    ch_names : array-like
+        Names of the channels.
+    window_size_samples : int
+        window size
+    window_stride_samples : int
+        stride between windows
+
+    Returns
+    -------
+    windows_datasets : BaseConcatDataset
+        X and y transformed to a dataset format that is compatible with skorch
+        and braindecode
+    """
+    # Prevent circular import
+    from ..preprocessing.windowers import (
+        create_fixed_length_windows,
+    )
+
+    n_samples_per_x = []
+    base_datasets = []
+    if ch_names is None:
+        ch_names = [str(i) for i in range(X.shape[1])]
+        log.info(f"No channel names given, set to 0-{X.shape[1]}.")
+
+    for x, target in zip(X, y):
+        n_samples_per_x.append(x.shape[1])
+        info = mne.create_info(ch_names=ch_names, sfreq=sfreq)
+        raw = mne.io.RawArray(x, info)
+        base_dataset = RawDataset(
+            raw, pd.Series({"target": target}), target_name="target"
+        )
+        base_datasets.append(base_dataset)
+    base_datasets = BaseConcatDataset(base_datasets)
+
+    if window_size_samples is None and window_stride_samples is None:
+        if not len(np.unique(n_samples_per_x)) == 1:
+            raise ValueError(
+                "if 'window_size_samples' and "
+                "'window_stride_samples' are None, "
+                "all trials have to have the same length"
+            )
+        window_size_samples = n_samples_per_x[0]
+        window_stride_samples = n_samples_per_x[0]
+    windows_datasets = create_fixed_length_windows(
+        base_datasets,
+        start_offset_samples=0,
+        stop_offset_samples=None,
+        window_size_samples=window_size_samples,
+        window_stride_samples=window_stride_samples,
+        drop_last_window=drop_last_window,
+    )
+    return windows_datasets
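To make the new entry point concrete, here is a minimal usage sketch of create_from_X_y as added above; the array shapes, channel names, and targets are illustrative assumptions, not values from the package.

# Minimal sketch: turning pre-cut trials into a windows dataset.
# Shapes, channel names, and targets below are illustrative assumptions.
import numpy as np

from braindecode.datasets import create_from_X_y

# 40 hypothetical pre-cut trials: 8 channels x 500 samples, sampled at 100 Hz
X = np.random.randn(40, 8, 500)
y = np.random.randint(0, 2, size=40)  # one binary target per trial

windows_dataset = create_from_X_y(
    X,
    y,
    drop_last_window=False,
    sfreq=100.0,
    ch_names=[f"ch{i}" for i in range(8)],
    # window_size_samples / window_stride_samples are left as None, so each
    # trial becomes exactly one window (all trials must then share one length)
)
print(len(windows_dataset))  # 40 windows, ready for a PyTorch DataLoader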
braindecode/datautil/__init__.py
@@ -0,0 +1,62 @@
+"""Utilities for data manipulation."""
+
+from .channel_utils import (
+    division_channels_idx,
+    match_hemisphere_chans,
+)
+from .serialization import (
+    _check_save_dir_empty,
+    load_concat_dataset,
+    save_concat_dataset,
+)
+from .util import infer_signal_properties
+
+
+def __getattr__(name):
+    # ideas from https://stackoverflow.com/a/57110249/1469195
+    import importlib
+    from warnings import warn
+
+    if name == "create_from_X_y":
+        warn(
+            "create_from_X_y has been moved to datasets, please use from braindecode.datasets import create_from_X_y"
+        )
+        xy = importlib.import_module("..datasets.xy", __package__)
+        return xy.create_from_X_y
+    if name in ["create_from_mne_raw", "create_from_mne_epochs"]:
+        warn(
+            f"{name} has been moved to datasets, please use from braindecode.datasets import {name}"
+        )
+        mne = importlib.import_module("..datasets.mne", __package__)
+        return mne.__dict__[name]
+    if name in [
+        "scale",
+        "exponential_moving_demean",
+        "exponential_moving_standardize",
+        "filterbank",
+        "preprocess",
+        "Preprocessor",
+    ]:
+        warn(
+            f"{name} has been moved to preprocessing, please use from braindecode.preprocessing import {name}"
+        )
+        preprocess = importlib.import_module("..preprocessing.preprocess", __package__)
+        return preprocess.__dict__[name]
+    if name in ["create_windows_from_events", "create_fixed_length_windows"]:
+        warn(
+            f"{name} has been moved to preprocessing, please use from braindecode.preprocessing import {name}"
+        )
+        windowers = importlib.import_module("..preprocessing.windowers", __package__)
+        return windowers.__dict__[name]
+
+    raise AttributeError("No possible import named " + name)
+
+
+__all__ = [
+    "load_concat_dataset",
+    "save_concat_dataset",
+    "_check_save_dir_empty",
+    "match_hemisphere_chans",
+    "division_channels_idx",
+    "infer_signal_properties",
+]
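The module-level __getattr__ above relies on PEP 562 lazy attribute lookup, so legacy braindecode.datautil import paths keep resolving while steering users to the new locations. A minimal sketch of the resulting behavior, assuming only the warning text shown in the diff:

import warnings

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # The legacy import path still resolves, routed through __getattr__ ...
    from braindecode.datautil import create_from_X_y
    # ... but a warning points to the new home in braindecode.datasets.
    assert any("moved to datasets" in str(w.message) for w in caught)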
braindecode/datautil/channel_utils.py
@@ -0,0 +1,114 @@
+"""
+Utilities for EEG channel manipulation and selection.
+
+This module provides functions for dividing and matching EEG channels,
+particularly for hemisphere-aware processing.
+"""
+
+import re
+from re import search
+
+
+def match_hemisphere_chans(left_chs, right_chs):
+    """
+    Match channels of the left and right hemispheres based on their names.
+
+    This function pairs channels from the left and right hemispheres by matching
+    their numeric identifiers. For a left channel with number N, it finds the
+    corresponding right channel with number N+1.
+
+    Parameters
+    ----------
+    left_chs : list of str
+        A list of channel names from the left hemisphere.
+    right_chs : list of str
+        A list of channel names from the right hemisphere.
+
+    Returns
+    -------
+    list of tuples
+        List of tuples with matched channel names from the left and right hemispheres.
+        Each tuple contains (left_channel, right_channel).
+
+    Raises
+    ------
+    ValueError
+        If the left and right channels do not match in length.
+    ValueError
+        If a channel name does not contain a number.
+    ValueError
+        If no matching right hemisphere channel is found for a left channel.
+
+    Examples
+    --------
+    >>> left = ['C3', 'F3']
+    >>> right = ['C4', 'F4']
+    >>> match_hemisphere_chans(left, right)
+    [('C3', 'C4'), ('F3', 'F4')]
+    """
+    if len(left_chs) != len(right_chs):
+        raise ValueError("Left and right channels do not match.")
+    right_chs = list(right_chs)
+    regexp = r"\d+"
+    out = []
+    for left in left_chs:
+        match = re.search(regexp, left)
+        if match is None:
+            raise ValueError(f"Channel '{left}' does not contain a number.")
+        chan_idx = 1 + int(match.group())
+        target_r = re.sub(regexp, str(chan_idx), left)
+        for right in right_chs:
+            if right == target_r:
+                out.append((left, right))
+                right_chs.remove(right)
+                break
+        else:
+            raise ValueError(
+                f"Found no right hemisphere matching channel for '{left}'."
+            )
+    return out
+
+
+def division_channels_idx(ch_names):
+    """
+    Divide EEG channel names into left, right, and middle based on numbering.
+
+    This function categorizes channels by their numeric suffix:
+    - Odd-numbered channels → left hemisphere
+    - Even-numbered channels → right hemisphere
+    - Channels without numbers → middle/midline
+
+    Parameters
+    ----------
+    ch_names : list of str
+        A list of EEG channel names to be divided based on their numbering.
+
+    Returns
+    -------
+    tuple of lists
+        Three lists containing the channel names:
+        - left: Odd-numbered channels (e.g., C3, F3, P3)
+        - right: Even-numbered channels (e.g., C4, F4, P4)
+        - middle: Channels without numbers (e.g., Cz, Fz, Pz)
+
+    Notes
+    -----
+    The function identifies channel numbers by searching for numeric characters
+    in the channel names. Standard 10-20 system EEG channel naming conventions
+    use odd numbers for left hemisphere and even numbers for right hemisphere.
+
+    Examples
+    --------
+    >>> channels = ['FP1', 'FP2', 'O1', 'O2', 'FZ']
+    >>> division_channels_idx(channels)
+    (['FP1', 'O1'], ['FP2', 'O2'], ['FZ'])
+    """
+    left, right, middle = [], [], []
+    for ch in ch_names:
+        number = search(r"\d+", ch)
+        if number is not None:
+            (left if int(number[0]) % 2 else right).append(ch)
+        else:
+            middle.append(ch)
+
+    return left, right, middle
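A short sketch that chains the two helpers on ordinary 10-20 channel names; the montage chosen here is only an illustration:

from braindecode.datautil import division_channels_idx, match_hemisphere_chans

chs = ["Fp1", "Fp2", "C3", "C4", "Cz", "O1", "O2", "Pz"]

# Odd suffixes go left, even suffixes go right, unnumbered names go middle.
left, right, middle = division_channels_idx(chs)
# left == ['Fp1', 'C3', 'O1'], right == ['Fp2', 'C4', 'O2'], middle == ['Cz', 'Pz']

# Pair each left channel numbered N with the right channel numbered N + 1.
pairs = match_hemisphere_chans(left, right)
# pairs == [('Fp1', 'Fp2'), ('C3', 'C4'), ('O1', 'O2')]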
braindecode/datautil/hub_formats.py
@@ -0,0 +1,180 @@
+"""
+Format converters for Hugging Face Hub integration.
+
+This module provides Zarr format converters to transform EEG datasets for
+efficient storage and fast random access during training on the Hugging Face Hub.
+
+This module provides a standalone functional API that delegates to the
+HubDatasetMixin methods for all actual implementations.
+"""
+
+# Authors: Kuntal Kokate
+#
+# License: BSD (3-clause)
+
+from __future__ import annotations
+
+import shutil
+from pathlib import Path
+from typing import TYPE_CHECKING, Dict, List, Optional, Union
+
+# Import registry for dynamic class lookup
+from ..datasets.registry import get_dataset_class
+
+# Import dataset classes for type checking only
+if TYPE_CHECKING:
+    from ..datasets.base import BaseConcatDataset
+
+
+# =============================================================================
+# Zarr Format Converters
+# =============================================================================
+
+
+def convert_to_zarr(
+    dataset: BaseConcatDataset,
+    output_path: Union[str, Path],
+    compression: str = "blosc",
+    compression_level: int = 5,
+    overwrite: bool = False,
+) -> Path:
+    """Convert BaseConcatDataset to Zarr format.
+
+    Zarr provides cloud-native chunked storage, optimized for random access
+    during training. This is the format used for Hugging Face Hub uploads,
+    based on comprehensive benchmarking showing:
+    - Fastest random access: 0.010 ms (critical for PyTorch DataLoader)
+    - Fast save/load: 0.46s / 0.12s
+    - Good compression: ~23% size reduction with blosc
+
+    Parameters
+    ----------
+    dataset : BaseConcatDataset
+        The dataset to convert.
+    output_path : str | Path
+        Path where the Zarr directory will be created.
+    compression : str, default="blosc"
+        Compression algorithm. Options: "blosc" (recommended), "zstd", "gzip", None.
+        blosc uses zstd codec by default, providing best balance of speed and compression.
+    compression_level : int, default=5
+        Compression level (0-9). Level 5 provides optimal balance based on benchmarks.
+    overwrite : bool, default=False
+        Whether to overwrite existing directory.
+
+    Returns
+    -------
+    Path
+        Path to the created Zarr directory.
+
+    Notes
+    -----
+    The chunking strategy is optimized for random access:
+    - Windowed data: Each window is a separate chunk (1, n_channels, n_times)
+    - Raw data: Chunks of (n_channels, 10000) samples
+
+    Examples
+    --------
+    >>> dataset = NMT(path=path, preload=True)
+    >>> # Use default settings (optimal from benchmarks)
+    >>> zarr_path = convert_to_zarr(dataset, "dataset.zarr")
+    >>>
+    >>> # Or customize compression
+    >>> zarr_path = convert_to_zarr(
+    ...     dataset, "dataset.zarr",
+    ...     compression="blosc",
+    ...     compression_level=5
+    ... )
+    """
+    output_path = Path(output_path)
+
+    if output_path.exists():
+        if not overwrite:
+            raise FileExistsError(
+                f"{output_path} already exists. Set overwrite=True to replace it."
+            )
+        # Remove existing directory if overwrite is True
+        shutil.rmtree(output_path)
+
+    # Delegate to HubDatasetMixin method
+    dataset._convert_to_zarr_inline(output_path, compression, compression_level)
+
+    return output_path
+
+
+def load_from_zarr(
+    input_path: Union[str, Path],
+    preload: bool = True,
+    ids_to_load: Optional[List[int]] = None,
+):
+    """Load BaseConcatDataset from Zarr format.
+
+    Zarr is the format used for braindecode Hub datasets, providing
+    the fastest random access performance for training with PyTorch.
+
+    Parameters
+    ----------
+    input_path : str | Path
+        Path to the Zarr directory.
+    preload : bool, default=True
+        Whether to load data into memory. If False, uses lazy loading
+        (data is loaded on-demand during training).
+    ids_to_load : list of int | None
+        Specific recording IDs to load. If None, loads all.
+
+    Returns
+    -------
+    BaseConcatDataset
+        The loaded dataset.
+
+    Examples
+    --------
+    >>> # Load from local zarr directory
+    >>> dataset = load_from_zarr("dataset.zarr", preload=True)
+    >>>
+    >>> # Load from Hugging Face Hub (handled automatically)
+    >>> from braindecode.datasets import BaseConcatDataset
+    >>> dataset = BaseConcatDataset.from_pretrained("username/dataset-name")
+    """
+    # Delegate to HubDatasetMixin static method
+    BaseConcatDataset = get_dataset_class("BaseConcatDataset")
+
+    # Load full dataset using mixin method
+    dataset = BaseConcatDataset._load_from_zarr_inline(Path(input_path), preload)
+
+    # Filter to specific IDs if requested
+    if ids_to_load is not None:
+        # Get only the requested datasets
+        filtered_datasets = [dataset.datasets[i] for i in ids_to_load]
+        dataset = BaseConcatDataset(filtered_datasets)
+
+    return dataset
+
+
+# =============================================================================
+# Utility Functions
+# =============================================================================
+
+
+def get_format_info(dataset: "BaseConcatDataset") -> Dict:
+    """Get dataset information for Hub metadata.
+
+    Validates that all datasets in the concat have uniform properties
+    (channels, sampling frequency) and raises an error if not.
+
+    Parameters
+    ----------
+    dataset : BaseConcatDataset
+        The dataset to analyze.
+
+    Returns
+    -------
+    dict
+        Dictionary with dataset statistics and format info.
+
+    Raises
+    ------
+    ValueError
+        If datasets have inconsistent channels or sampling frequencies.
+    """
+    # Delegate to HubDatasetMixin method
+    return dataset._get_format_info_inline()
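Finally, a hedged round-trip sketch for the two Zarr converters above. The braindecode.datautil.hub_formats import path is inferred from the file listing at the top of this diff, and the synthetic dataset built via create_from_X_y is an assumption for illustration:

import numpy as np

from braindecode.datasets import create_from_X_y
from braindecode.datautil.hub_formats import convert_to_zarr, load_from_zarr

# Tiny synthetic dataset: 10 trials, 4 channels, 250 samples at 125 Hz.
X = np.random.randn(10, 4, 250)
y = np.zeros(10, dtype=int)
dataset = create_from_X_y(X, y, drop_last_window=False, sfreq=125.0)

# Write with the benchmarked defaults (blosc codec, level 5), then reload
# lazily and keep only the first two recordings.
zarr_path = convert_to_zarr(dataset, "tiny.zarr", overwrite=True)
reloaded = load_from_zarr(zarr_path, preload=False, ids_to_load=[0, 1])
assert len(reloaded.datasets) == 2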