braindecode 1.3.0.dev177069446__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- braindecode/__init__.py +9 -0
- braindecode/augmentation/__init__.py +52 -0
- braindecode/augmentation/base.py +225 -0
- braindecode/augmentation/functional.py +1300 -0
- braindecode/augmentation/transforms.py +1356 -0
- braindecode/classifier.py +258 -0
- braindecode/datasets/__init__.py +44 -0
- braindecode/datasets/base.py +823 -0
- braindecode/datasets/bbci.py +693 -0
- braindecode/datasets/bcicomp.py +193 -0
- braindecode/datasets/bids/__init__.py +54 -0
- braindecode/datasets/bids/datasets.py +239 -0
- braindecode/datasets/bids/format.py +717 -0
- braindecode/datasets/bids/hub.py +987 -0
- braindecode/datasets/bids/hub_format.py +717 -0
- braindecode/datasets/bids/hub_io.py +197 -0
- braindecode/datasets/bids/hub_validation.py +114 -0
- braindecode/datasets/bids/iterable.py +220 -0
- braindecode/datasets/chb_mit.py +163 -0
- braindecode/datasets/mne.py +170 -0
- braindecode/datasets/moabb.py +219 -0
- braindecode/datasets/nmt.py +313 -0
- braindecode/datasets/registry.py +120 -0
- braindecode/datasets/siena.py +162 -0
- braindecode/datasets/sleep_physio_challe_18.py +411 -0
- braindecode/datasets/sleep_physionet.py +125 -0
- braindecode/datasets/tuh.py +591 -0
- braindecode/datasets/utils.py +67 -0
- braindecode/datasets/xy.py +96 -0
- braindecode/datautil/__init__.py +62 -0
- braindecode/datautil/channel_utils.py +114 -0
- braindecode/datautil/hub_formats.py +180 -0
- braindecode/datautil/serialization.py +359 -0
- braindecode/datautil/util.py +154 -0
- braindecode/eegneuralnet.py +372 -0
- braindecode/functional/__init__.py +22 -0
- braindecode/functional/functions.py +251 -0
- braindecode/functional/initialization.py +47 -0
- braindecode/models/__init__.py +117 -0
- braindecode/models/atcnet.py +830 -0
- braindecode/models/attentionbasenet.py +727 -0
- braindecode/models/attn_sleep.py +549 -0
- braindecode/models/base.py +574 -0
- braindecode/models/bendr.py +493 -0
- braindecode/models/biot.py +537 -0
- braindecode/models/brainmodule.py +845 -0
- braindecode/models/config.py +233 -0
- braindecode/models/contrawr.py +319 -0
- braindecode/models/ctnet.py +541 -0
- braindecode/models/deep4.py +376 -0
- braindecode/models/deepsleepnet.py +417 -0
- braindecode/models/eegconformer.py +475 -0
- braindecode/models/eeginception_erp.py +379 -0
- braindecode/models/eeginception_mi.py +379 -0
- braindecode/models/eegitnet.py +302 -0
- braindecode/models/eegminer.py +256 -0
- braindecode/models/eegnet.py +359 -0
- braindecode/models/eegnex.py +354 -0
- braindecode/models/eegsimpleconv.py +201 -0
- braindecode/models/eegsym.py +917 -0
- braindecode/models/eegtcnet.py +337 -0
- braindecode/models/fbcnet.py +225 -0
- braindecode/models/fblightconvnet.py +315 -0
- braindecode/models/fbmsnet.py +338 -0
- braindecode/models/hybrid.py +126 -0
- braindecode/models/ifnet.py +443 -0
- braindecode/models/labram.py +1316 -0
- braindecode/models/luna.py +891 -0
- braindecode/models/medformer.py +760 -0
- braindecode/models/msvtnet.py +377 -0
- braindecode/models/patchedtransformer.py +640 -0
- braindecode/models/reve.py +843 -0
- braindecode/models/sccnet.py +280 -0
- braindecode/models/shallow_fbcsp.py +212 -0
- braindecode/models/signal_jepa.py +1122 -0
- braindecode/models/sinc_shallow.py +339 -0
- braindecode/models/sleep_stager_blanco_2020.py +169 -0
- braindecode/models/sleep_stager_chambon_2018.py +159 -0
- braindecode/models/sparcnet.py +426 -0
- braindecode/models/sstdpn.py +869 -0
- braindecode/models/summary.csv +47 -0
- braindecode/models/syncnet.py +234 -0
- braindecode/models/tcn.py +275 -0
- braindecode/models/tidnet.py +397 -0
- braindecode/models/tsinception.py +295 -0
- braindecode/models/usleep.py +439 -0
- braindecode/models/util.py +369 -0
- braindecode/modules/__init__.py +92 -0
- braindecode/modules/activation.py +86 -0
- braindecode/modules/attention.py +883 -0
- braindecode/modules/blocks.py +160 -0
- braindecode/modules/convolution.py +330 -0
- braindecode/modules/filter.py +654 -0
- braindecode/modules/layers.py +216 -0
- braindecode/modules/linear.py +70 -0
- braindecode/modules/parametrization.py +38 -0
- braindecode/modules/stats.py +87 -0
- braindecode/modules/util.py +85 -0
- braindecode/modules/wrapper.py +90 -0
- braindecode/preprocessing/__init__.py +271 -0
- braindecode/preprocessing/eegprep_preprocess.py +1317 -0
- braindecode/preprocessing/mne_preprocess.py +240 -0
- braindecode/preprocessing/preprocess.py +579 -0
- braindecode/preprocessing/util.py +177 -0
- braindecode/preprocessing/windowers.py +1037 -0
- braindecode/regressor.py +234 -0
- braindecode/samplers/__init__.py +18 -0
- braindecode/samplers/base.py +399 -0
- braindecode/samplers/ssl.py +263 -0
- braindecode/training/__init__.py +23 -0
- braindecode/training/callbacks.py +23 -0
- braindecode/training/losses.py +105 -0
- braindecode/training/scoring.py +477 -0
- braindecode/util.py +419 -0
- braindecode/version.py +1 -0
- braindecode/visualization/__init__.py +8 -0
- braindecode/visualization/confusion_matrices.py +289 -0
- braindecode/visualization/gradients.py +62 -0
- braindecode-1.3.0.dev177069446.dist-info/METADATA +230 -0
- braindecode-1.3.0.dev177069446.dist-info/RECORD +124 -0
- braindecode-1.3.0.dev177069446.dist-info/WHEEL +5 -0
- braindecode-1.3.0.dev177069446.dist-info/licenses/LICENSE.txt +31 -0
- braindecode-1.3.0.dev177069446.dist-info/licenses/NOTICE.txt +20 -0
- braindecode-1.3.0.dev177069446.dist-info/top_level.txt +1 -0
braindecode/datasets/bids/hub_io.py
@@ -0,0 +1,197 @@
+# mypy: ignore-errors
+"""
+Low-level Zarr I/O helpers for Hub integration.
+
+These functions keep the Zarr serialization details isolated from hub.py.
+"""
+
+from __future__ import annotations
+
+import json
+import warnings
+from pathlib import Path
+
+import numpy as np
+import pandas as pd
+from mne.utils import _soft_import
+
+zarr = _soft_import("zarr", purpose="hugging face integration", strict=False)
+
+
+def _sanitize_for_json(obj):
+    """Replace NaN/Inf with None for valid JSON."""
+    if isinstance(obj, float):
+        if np.isnan(obj) or np.isinf(obj):
+            return None
+        return obj
+    if isinstance(obj, dict):
+        return {k: _sanitize_for_json(v) for k, v in obj.items()}
+    if isinstance(obj, list):
+        return [_sanitize_for_json(v) for v in obj]
+    if isinstance(obj, np.ndarray):
+        return _sanitize_for_json(obj.tolist())
+    return obj
+
+
+def _restore_nan_from_json(obj):
+    """Restore NaN values from None in JSON-loaded data."""
+    if isinstance(obj, dict):
+        return {k: _restore_nan_from_json(v) for k, v in obj.items()}
+    if isinstance(obj, list):
+        if len(obj) > 0 and all(isinstance(x, (int, float, type(None))) for x in obj):
+            return [np.nan if x is None else x for x in obj]
+        return [_restore_nan_from_json(v) for v in obj]
+    return obj
+
+
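For orientation, a small sketch of how the two helpers above round-trip (the import path follows the file list; note the asymmetry: only None values inside all-numeric lists are mapped back to NaN, scalar null attributes stay None, and Inf comes back as NaN):

    import json

    import numpy as np

    from braindecode.datasets.bids.hub_io import (
        _restore_nan_from_json,
        _sanitize_for_json,
    )

    # NaN/Inf -> None (JSON null); ndarray -> nested list.
    info = {"sfreq": 256.0, "lowpass": float("nan"), "loc": np.array([1.0, np.inf])}
    encoded = json.dumps(_sanitize_for_json(info))
    # '{"sfreq": 256.0, "lowpass": null, "loc": [1.0, null]}'
    decoded = _restore_nan_from_json(json.loads(encoded))
    # {'sfreq': 256.0, 'lowpass': None, 'loc': [1.0, nan]}
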
+def _save_windows_to_zarr(
+    grp, data, metadata, description, info, compressor, target_name
+):
+    """Save windowed data to Zarr group (low-level function)."""
+    data_array = data.astype(np.float32)
+    compressors_list = [compressor] if compressor is not None else None
+
+    grp.create_array(
+        "data",
+        data=data_array,
+        chunks=(1, data_array.shape[1], data_array.shape[2]),
+        compressors=compressors_list,
+    )
+
+    store_path = getattr(grp.store, "path", getattr(grp.store, "root", None))
+    metadata_path = Path(store_path) / grp.path / "metadata.tsv"
+    metadata.to_csv(metadata_path, sep="\t", index=True)
+
+    grp.attrs["description"] = json.loads(description.to_json(date_format="iso"))
+    grp.attrs["info"] = _sanitize_for_json(info)
+
+    if target_name is not None:
+        grp.attrs["target_name"] = target_name
+
+
+def _save_eegwindows_to_zarr(
+    grp, raw, metadata, description, info, targets_from, last_target_only, compressor
+):
+    """Save EEG continuous raw data to Zarr group (low-level function)."""
+    continuous_data = raw.get_data()
+    continuous_float = continuous_data.astype(np.float32)
+    compressors_list = [compressor] if compressor is not None else None
+
+    grp.create_array(
+        "data",
+        data=continuous_float,
+        chunks=(continuous_float.shape[0], min(10000, continuous_float.shape[1])),
+        compressors=compressors_list,
+    )
+
+    store_path = getattr(grp.store, "path", getattr(grp.store, "root", None))
+    metadata_path = Path(store_path) / grp.path / "metadata.tsv"
+    metadata.to_csv(metadata_path, sep="\t", index=True)
+
+    grp.attrs["description"] = json.loads(description.to_json(date_format="iso"))
+    grp.attrs["info"] = _sanitize_for_json(info)
+    grp.attrs["targets_from"] = targets_from
+    grp.attrs["last_target_only"] = last_target_only
+
+
+def _load_windows_from_zarr(grp, preload):
+    """Load windowed data from Zarr group (low-level function)."""
+    store_path = getattr(grp.store, "path", getattr(grp.store, "root", None))
+    metadata_path = Path(store_path) / grp.path / "metadata.tsv"
+    metadata = pd.read_csv(metadata_path, sep="\t", index_col=0)
+
+    description = pd.Series(grp.attrs["description"])
+    info_dict = _restore_nan_from_json(grp.attrs["info"])
+
+    if preload:
+        data = grp["data"][:]
+    else:
+        data = grp["data"][:]
+        warnings.warn(
+            "Lazy loading from Zarr not fully implemented yet. "
+            "Loading all data into memory.",
+            UserWarning,
+        )
+
+    target_name = grp.attrs.get("target_name", None)
+
+    return data, metadata, description, info_dict, target_name
+
+
+def _load_eegwindows_from_zarr(grp, preload):
+    """Load EEG continuous raw data from Zarr group (low-level function)."""
+    store_path = getattr(grp.store, "path", getattr(grp.store, "root", None))
+    metadata_path = Path(store_path) / grp.path / "metadata.tsv"
+    metadata = pd.read_csv(metadata_path, sep="\t", index_col=0)
+
+    description = pd.Series(grp.attrs["description"])
+    info_dict = _restore_nan_from_json(grp.attrs["info"])
+
+    if preload:
+        data = grp["data"][:]
+    else:
+        data = grp["data"][:]
+        warnings.warn(
+            "Lazy loading from Zarr not fully implemented yet. "
+            "Loading all data into memory.",
+            UserWarning,
+        )
+
+    targets_from = grp.attrs.get("targets_from", "metadata")
+    last_target_only = grp.attrs.get("last_target_only", True)
+
+    return data, metadata, description, info_dict, targets_from, last_target_only
+
+
+def _save_raw_to_zarr(grp, raw, description, info, target_name, compressor):
+    """Save RawDataset continuous raw data to Zarr group (low-level function)."""
+    continuous_data = raw.get_data()
+    continuous_float = continuous_data.astype(np.float32)
+    compressors_list = [compressor] if compressor is not None else None
+
+    grp.create_array(
+        "data",
+        data=continuous_float,
+        chunks=(continuous_float.shape[0], min(10000, continuous_float.shape[1])),
+        compressors=compressors_list,
+    )
+
+    grp.attrs["description"] = json.loads(description.to_json(date_format="iso"))
+    grp.attrs["info"] = _sanitize_for_json(info)
+
+    if target_name is not None:
+        grp.attrs["target_name"] = target_name
+
+
+def _load_raw_from_zarr(grp, preload):
+    """Load RawDataset continuous raw data from Zarr group (low-level function)."""
+    description = pd.Series(grp.attrs["description"])
+    info_dict = _restore_nan_from_json(grp.attrs["info"])
+
+    if preload:
+        data = grp["data"][:]
+    else:
+        data = grp["data"][:]
+        warnings.warn(
+            "Lazy loading from Zarr not fully implemented yet. "
+            "Loading all data into memory.",
+            UserWarning,
+        )
+
+    target_name = grp.attrs.get("target_name", None)
+
+    return data, description, info_dict, target_name
+
+
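A rough round-trip sketch of the windowed save/load pair above, against a throwaway local store (assumes the optional zarr v3 dependency is installed; shapes, metadata, and the toy.zarr path are placeholder values):

    import numpy as np
    import pandas as pd
    import zarr

    from braindecode.datasets.bids.hub_io import (
        _create_compressor,
        _load_windows_from_zarr,
        _save_windows_to_zarr,
    )

    grp = zarr.open_group("toy.zarr", mode="w").create_group("0")
    data = np.random.randn(5, 2, 100)  # (n_windows, n_channels, n_times)
    metadata = pd.DataFrame({"target": np.arange(5)})
    description = pd.Series({"subject": "01"})
    info = {"sfreq": 100.0, "ch_names": ["C3", "C4"], "lowpass": float("nan")}

    _save_windows_to_zarr(
        grp, data, metadata, description, info,
        compressor=_create_compressor("zstd", compression_level=5),
        target_name="target",
    )
    # Data comes back as float32; metadata.tsv is read from next to the store.
    loaded, metadata2, description2, info2, target_name = _load_windows_from_zarr(
        grp, preload=True
    )
    assert loaded.shape == (5, 2, 100) and target_name == "target"
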
+def _create_compressor(compression, compression_level):
+    """Create a Zarr v3 compressor codec."""
+    if zarr is False:
+        raise ImportError(
+            "Zarr is not installed. Install with: pip install braindecode[hub]"
+        )
+
+    if compression is None or compression not in ("blosc", "zstd", "gzip"):
+        return None
+
+    name = "zstd" if compression == "blosc" else compression
+    return {"name": name, "configuration": {"level": compression_level}}
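Note the mapping in _create_compressor: "blosc" is translated to "zstd", and unrecognized names fall back to None (no compression) rather than raising. Continuing the sketch above:

    print(_create_compressor("blosc", 5))  # {'name': 'zstd', 'configuration': {'level': 5}}
    print(_create_compressor("gzip", 9))   # {'name': 'gzip', 'configuration': {'level': 9}}
    print(_create_compressor("lz4", 5))    # None -- not in ("blosc", "zstd", "gzip")
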
braindecode/datasets/bids/hub_validation.py
@@ -0,0 +1,114 @@
+# mypy: ignore-errors
+"""
+Shared validation utilities for Hub format operations.
+
+This module provides validation functions used by hub.py to avoid code duplication.
+"""
+
+# Authors: Kuntal Kokate
+#
+# License: BSD (3-clause)
+
+from typing import Any, List, Tuple
+
+from ..registry import get_dataset_type
+
+
+def validate_dataset_uniformity(
+    datasets: List[Any],
+) -> Tuple[str, List[str], float]:
+    """
+    Validate all datasets have uniform type, channels, and sampling frequency.
+
+    Parameters
+    ----------
+    datasets : list
+        List of dataset objects to validate.
+
+    Returns
+    -------
+    dataset_type : str
+        The validated dataset type (WindowsDataset, EEGWindowsDataset, or RawDataset).
+    first_ch_names : list of str
+        Channel names from the first dataset.
+    first_sfreq : float
+        Sampling frequency from the first dataset.
+
+    Raises
+    ------
+    ValueError
+        If datasets have mixed types, inconsistent channels, or inconsistent
+        sampling frequencies.
+    TypeError
+        If dataset type is not supported.
+    """
+    if not datasets:
+        raise ValueError("No datasets provided for validation.")
+
+    first_ds = datasets[0]
+    dataset_type = get_dataset_type(first_ds)
+
+    # Get reference channel names and sampling frequency from the first dataset
+    first_ch_names, first_sfreq = _get_ch_names_and_sfreq(first_ds, dataset_type)
+
+    # Validate all datasets have uniform properties
+    for i, ds in enumerate(datasets):
+        ds_type = get_dataset_type(ds)
+        if ds_type != dataset_type:
+            raise ValueError(
+                f"Mixed dataset types in concat: dataset 0 is {dataset_type} "
+                f"but dataset {i} is {ds_type}"
+            )
+
+        ch_names, sfreq = _get_ch_names_and_sfreq(ds, dataset_type)
+
+        if ch_names != first_ch_names:
+            raise ValueError(
+                f"Inconsistent channel names: dataset 0 has {first_ch_names} "
+                f"but dataset {i} has {ch_names}"
+            )
+
+        if sfreq != first_sfreq:
+            _raise_sfreq_error(first_sfreq, sfreq, i)
+
+    return dataset_type, first_ch_names, first_sfreq
+
+
+def _get_ch_names_and_sfreq(ds: Any, dataset_type: str) -> Tuple[List[str], float]:
+    """Return (ch_names, sfreq) for supported dataset types."""
+    if dataset_type == "WindowsDataset":
+        obj = ds.windows
+    elif dataset_type in ("EEGWindowsDataset", "RawDataset"):
+        obj = ds.raw
+    else:
+        raise TypeError(f"Unsupported dataset type: {dataset_type}")
+
+    return obj.ch_names, obj.info["sfreq"]
+
+
+def _raise_sfreq_error(expected: float, actual: float, idx: int):
+    """
+    Raise standardized sampling frequency error.
+
+    Parameters
+    ----------
+    expected : float
+        Expected sampling frequency from dataset 0.
+    actual : float
+        Actual sampling frequency from current dataset.
+    idx : int
+        Index of the dataset with inconsistent sampling frequency.
+
+    Raises
+    ------
+    ValueError
+        Always raised with standardized error message.
+    """
+    raise ValueError(
+        f"Inconsistent sampling frequencies: dataset 0 has {expected} Hz "
+        f"but dataset {idx} has {actual} Hz. "
+        f"Please resample all datasets to a common frequency before saving. "
+        f"Use braindecode.preprocessing.preprocess("
+        f"[Preprocessor(Resample(sfreq={expected}))], concat_ds) "
+        f"to resample your datasets."
+    )
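A sketch of the happy path through validate_dataset_uniformity (assuming RawDataset wraps an mne.io.Raw as in the BIDSIterableDataset docstring later in this diff, and that get_dataset_type reports it as "RawDataset"):

    import mne
    import numpy as np

    from braindecode.datasets import BaseConcatDataset, RawDataset
    from braindecode.datasets.bids.hub_validation import validate_dataset_uniformity

    # Two toy recordings with identical channels and sampling frequency.
    info = mne.create_info(ch_names=["C3", "C4"], sfreq=100.0, ch_types="eeg")
    raws = [mne.io.RawArray(np.random.randn(2, 500) * 1e-6, info.copy()) for _ in range(2)]
    concat_ds = BaseConcatDataset([RawDataset(raw) for raw in raws])

    dataset_type, ch_names, sfreq = validate_dataset_uniformity(concat_ds.datasets)
    print(dataset_type, ch_names, sfreq)  # RawDataset ['C3', 'C4'] 100.0
    # Resampling one recording first would instead trigger _raise_sfreq_error.
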
braindecode/datasets/bids/iterable.py
@@ -0,0 +1,220 @@
+from __future__ import annotations
+
+import random
+from pathlib import Path
+from typing import Callable, Sequence
+
+import mne_bids
+from torch.utils.data import IterableDataset, get_worker_info
+
+
+class BIDSIterableDataset(IterableDataset):
+    """Dataset for loading BIDS.
+
+    .. warning::
+        This class is experimental and may change in the future.
+
+    .. warning::
+        This dataset is not consistent with the Braindecode API.
+
+    This class has the same parameters as the :func:`mne_bids.find_matching_paths` function,
+    as it is used to find the files to load. The default ``extensions`` parameter was changed.
+
+    More information on BIDS (Brain Imaging Data Structure)
+    can be found at https://bids.neuroimaging.io
+
+    Examples
+    --------
+    >>> from braindecode.datasets import BaseConcatDataset, RawDataset, RecordDataset
+    >>> from braindecode.datasets.bids import BIDSIterableDataset
+    >>> from braindecode.preprocessing import create_fixed_length_windows
+    >>>
+    >>> def my_reader_fn(path):
+    ...     raw = mne_bids.read_raw_bids(path)
+    ...     ds: RecordDataset = RawDataset(raw, description={"path": path.fpath})
+    ...     windows_ds = create_fixed_length_windows(
+    ...         BaseConcatDataset([ds]),
+    ...         window_size_samples=400,
+    ...         window_stride_samples=200,
+    ...     )
+    ...     return windows_ds
+    >>>
+    >>> dataset = BIDSIterableDataset(
+    ...     reader_fn=my_reader_fn,
+    ...     root="root/of/my/bids/dataset/",
+    ... )
+
+    Parameters
+    ----------
+    reader_fn : Callable[[mne_bids.BIDSPath], Sequence]
+        A function that takes a BIDSPath and returns a dataset (e.g., a
+        RecordDataset or BaseConcatDataset of RecordDataset).
+    pool_size : int
+        The number of recordings to read and sample from.
+    bids_paths : list[mne_bids.BIDSPath] | None
+        A list of BIDSPaths to load. If None, will use the paths found by
+        :func:`mne_bids.find_matching_paths` and the arguments below.
+    root : pathlib.Path | str
+        The root of the BIDS path.
+    subjects : str | array-like of str | None
+        The subject ID. Corresponds to "sub".
+    sessions : str | array-like of str | None
+        The acquisition session. Corresponds to "ses".
+    tasks : str | array-like of str | None
+        The experimental task. Corresponds to "task".
+    acquisitions : str | array-like of str | None
+        The acquisition parameters. Corresponds to "acq".
+    runs : str | array-like of str | None
+        The run number. Corresponds to "run".
+    processings : str | array-like of str | None
+        The processing label. Corresponds to "proc".
+    recordings : str | array-like of str | None
+        The recording name. Corresponds to "rec".
+    spaces : str | array-like of str | None
+        The coordinate space for anatomical and sensor location
+        files (e.g., ``*_electrodes.tsv``, ``*_markers.mrk``).
+        Corresponds to "space".
+        Note that valid values for ``space`` must come from a list
+        of BIDS keywords as described in the BIDS specification.
+    splits : str | array-like of str | None
+        The split of the continuous recording file for ``.fif`` data.
+        Corresponds to "split".
+    descriptions : str | array-like of str | None
+        This corresponds to the BIDS entity ``desc``. It is used to provide
+        additional information for derivative data, e.g., preprocessed data
+        may be assigned ``description='cleaned'``.
+    suffixes : str | array-like of str | None
+        The filename suffix. This is the entity after the
+        last ``_`` before the extension. E.g., ``'channels'``.
+        The following filename suffixes are accepted:
+        'meg', 'markers', 'eeg', 'ieeg', 'T1w',
+        'participants', 'scans', 'electrodes', 'coordsystem',
+        'channels', 'events', 'headshape', 'digitizer',
+        'beh', 'physio', 'stim'
+    extensions : str | array-like of str | None
+        The extension of the filename. E.g., ``'.json'``.
+        By default, uses the ones accepted by :func:`mne_bids.read_raw_bids`.
+    datatypes : str | array-like of str | None
+        The BIDS data type, e.g., ``'anat'``, ``'func'``, ``'eeg'``, ``'meg'``,
+        ``'ieeg'``.
+    check : bool
+        If ``True``, only returns paths that conform to BIDS. If ``False``
+        (default), the ``.check`` attribute of the returned
+        :class:`mne_bids.BIDSPath` object will be set to ``True`` for paths that
+        do conform to BIDS, and to ``False`` for those that don't.
+    preload : bool
+        If True, preload the data. Defaults to False.
+    n_jobs : int
+        Number of jobs to run in parallel. Defaults to 1.
+
+
+    """
+
+    def __init__(
+        self,
+        reader_fn: Callable[[mne_bids.BIDSPath], Sequence],
+        pool_size: int = 4,
+        bids_paths: list[mne_bids.BIDSPath] | None = None,
+        root: Path | str | None = None,
+        subjects: str | list[str] | None = None,
+        sessions: str | list[str] | None = None,
+        tasks: str | list[str] | None = None,
+        acquisitions: str | list[str] | None = None,
+        runs: str | list[str] | None = None,
+        processings: str | list[str] | None = None,
+        recordings: str | list[str] | None = None,
+        spaces: str | list[str] | None = None,
+        splits: str | list[str] | None = None,
+        descriptions: str | list[str] | None = None,
+        suffixes: str | list[str] | None = None,
+        extensions: str | list[str] | None = [
+            ".con",
+            ".sqd",
+            ".pdf",
+            ".fif",
+            ".ds",
+            ".vhdr",
+            ".set",
+            ".edf",
+            ".bdf",
+            ".EDF",
+            ".snirf",
+            ".cdt",
+            ".mef",
+            ".nwb",
+        ],
+        datatypes: str | list[str] | None = None,
+        check: bool = False,
+    ):
+        if bids_paths is None:
+            bids_paths = mne_bids.find_matching_paths(
+                root=root,
+                subjects=subjects,
+                sessions=sessions,
+                tasks=tasks,
+                acquisitions=acquisitions,
+                runs=runs,
+                processings=processings,
+                recordings=recordings,
+                spaces=spaces,
+                splits=splits,
+                descriptions=descriptions,
+                suffixes=suffixes,
+                extensions=extensions,
+                datatypes=datatypes,
+                check=check,
+                ignore_json=True,
+            )
+            # Filter out _epo.fif files:
+            bids_paths = [
+                bids_path
+                for bids_path in bids_paths
+                if not (bids_path.suffix == "epo" and bids_path.extension == ".fif")
+            ]
+        self.bids_paths = bids_paths
+        self.reader_fn = reader_fn
+        self.pool_size = pool_size
+
+    def __add__(self, other):
+        assert isinstance(other, BIDSIterableDataset)
+        return BIDSIterableDataset(
+            reader_fn=self.reader_fn,
+            bids_paths=self.bids_paths + other.bids_paths,
+            pool_size=self.pool_size,
+        )
+
+    def __iadd__(self, other):
+        assert isinstance(other, BIDSIterableDataset)
+        self.bids_paths += other.bids_paths
+        return self
+
+    def __iter__(self):
+        worker_info = get_worker_info()
+        if worker_info is None:  # single-process data loading, return the full iterator
+            bids_paths = self.bids_paths
+        else:  # in a worker process
+            # split workload
+            bids_paths = self.bids_paths[worker_info.id :: worker_info.num_workers]
+
+        pool = []
+        end = False
+        paths_it = iter(random.sample(bids_paths, k=len(bids_paths)))
+        while not (end and len(pool) == 0):
+            while not end and len(pool) < self.pool_size:
+                try:
+                    bids_path = next(paths_it)
+                    ds = self.reader_fn(bids_path)
+                    if ds is None:
+                        print(f"Skipping {bids_path} as it is too short.")
+                        continue
+                    idx = iter(random.sample(range(len(ds)), k=len(ds)))
+                    pool.append((ds, idx))
+                except StopIteration:
+                    end = True
+            i_pool = random.randint(0, len(pool) - 1)
+            ds, idx = pool[i_pool]
+            try:
+                i_ds = next(idx)
+                yield ds[i_ds]
+            except StopIteration:
+                pool.pop(i_pool)
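Because __iter__ shards bids_paths across workers via get_worker_info, the dataset can be fed straight to a PyTorch DataLoader. A hedged sketch reusing the docstring's reader pattern ("path/to/bids" is a placeholder):

    import mne_bids
    from torch.utils.data import DataLoader

    from braindecode.datasets import BaseConcatDataset, RawDataset
    from braindecode.datasets.bids import BIDSIterableDataset
    from braindecode.preprocessing import create_fixed_length_windows


    def reader_fn(path):
        # One recording -> fixed-length windows, as in the docstring example.
        raw = mne_bids.read_raw_bids(path)
        ds = RawDataset(raw, description={"path": path.fpath})
        return create_fixed_length_windows(
            BaseConcatDataset([ds]),
            window_size_samples=400,
            window_stride_samples=200,
        )


    dataset = BIDSIterableDataset(reader_fn=reader_fn, pool_size=4, root="path/to/bids")
    # Each worker keeps a shuffled pool of up to pool_size open recordings
    # and yields their windows in random order.
    loader = DataLoader(dataset, batch_size=64, num_workers=2)
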
braindecode/datasets/chb_mit.py
@@ -0,0 +1,163 @@
+"""
+This dataset is a BIDS-compatible version of the CHB-MIT Scalp EEG Database.
+
+It reorganizes the file structure to comply with the BIDS specification. To this effect:
+
+The data from subject chb21 was moved to sub-01/ses-02.
+Metadata was organized according to BIDS.
+Data in the EEG edf files was modified to keep only the 18 channels from a double banana bipolar montage.
+Annotations were formatted as BIDS-score compatible `tsv` files.
+"""
+
+# Authors: Dan, Jonathan
+#          Shoeb, Ali (Data collector)
+#          Bruno Aristimunha <b.aristimunha@gmail.com>
+#
+# License: BSD (3-clause)
+from __future__ import annotations
+
+from pathlib import Path
+
+from mne.datasets import fetch_dataset
+
+from braindecode.datasets import BIDSDataset
+from braindecode.datasets.utils import _correct_dataset_path
+
+CHB_MIT_URL = "https://zenodo.org/records/10259996/files/BIDS_CHB-MIT.zip"
+CHB_MIT_archive_name = "chb_mit_bids.zip"
+CHB_MIT_folder_name = "CHB-MIT-BIDS-eeg-dataset"
+CHB_MIT_dataset_name = "CHB-MIT-EEG-Corpus"
+
+CHB_MIT_dataset_params = {
+    "dataset_name": CHB_MIT_dataset_name,
+    "url": CHB_MIT_URL,
+    "archive_name": CHB_MIT_archive_name,
+    "folder_name": CHB_MIT_folder_name,
+    "hash": "078f4e110e40d10fef1a38a892571ad24666c488e8118a01002c9224909256ed",  # sha256
+    "config_key": CHB_MIT_dataset_name,
+}
+
+
+class CHBMIT(BIDSDataset):
+    """The Children's Hospital Boston EEG Dataset.
+
+    This database, collected at the Children's Hospital Boston, consists of EEG recordings
+    from pediatric subjects with intractable seizures. Subjects were monitored for up to
+    several days following withdrawal of anti-seizure medication in order to characterize
+    their seizures and assess their candidacy for surgical intervention.
+
+    **Description of the contents of the dataset:**
+
+    Each folder (sub-01, sub-02, etc.) contains between 9 and 42 continuous .edf
+    files from a single subject. Hardware limitations resulted in gaps between
+    consecutively-numbered .edf files, during which the signals were not recorded;
+    in most cases, the gaps are 10 seconds or less, but occasionally there are much
+    longer gaps. In order to protect the privacy of the subjects, all protected health
+    information (PHI) in the original .edf files has been replaced with surrogate information
+    in the files provided here. Dates in the original .edf files have been replaced by
+    surrogate dates, but the time relationships between the individual files belonging
+    to each case have been preserved. In most cases, the .edf files contain exactly one
+    hour of digitized EEG signals, although those belonging to case sub-10 are two hours
+    long, and those belonging to cases sub-04, sub-06, sub-07, sub-09, and sub-23 are
+    four hours long; occasionally, files in which seizures are recorded are shorter.
+
+    The EEG is recorded at 256 Hz with a 16-bit resolution. The recordings are
+    referenced in a double banana bipolar montage with 18 channels from the 10-20 electrode system.
+
+    This BIDS-compatible version of the dataset was published by Jonathan Dan :footcite:`Dan2025`
+    and is based on the original CHB-MIT EEG Database :footcite:`Guttag2010`, :footcite:`Shoeb2009`.
+
+    .. versionadded:: 1.3
+
+    Parameters
+    ----------
+    root : pathlib.Path | str
+        The root of the BIDS path.
+    subjects : str | array-like of str | None
+        The subject ID. Corresponds to "sub".
+    sessions : str | array-like of str | None
+        The acquisition session. Corresponds to "ses".
+    tasks : str | array-like of str | None
+        The experimental task. Corresponds to "task".
+    acquisitions : str | array-like of str | None
+        The acquisition parameters. Corresponds to "acq".
+    runs : str | array-like of str | None
+        The run number. Corresponds to "run".
+    processings : str | array-like of str | None
+        The processing label. Corresponds to "proc".
+    recordings : str | array-like of str | None
+        The recording name. Corresponds to "rec".
+    spaces : str | array-like of str | None
+        The coordinate space for anatomical and sensor location
+        files (e.g., ``*_electrodes.tsv``, ``*_markers.mrk``).
+        Corresponds to "space".
+        Note that valid values for ``space`` must come from a list
+        of BIDS keywords as described in the BIDS specification.
+    splits : str | array-like of str | None
+        The split of the continuous recording file for ``.fif`` data.
+        Corresponds to "split".
+    descriptions : str | array-like of str | None
+        This corresponds to the BIDS entity ``desc``. It is used to provide
+        additional information for derivative data, e.g., preprocessed data
+        may be assigned ``description='cleaned'``.
+    suffixes : str | array-like of str | None
+        The filename suffix. This is the entity after the
+        last ``_`` before the extension. E.g., ``'channels'``.
+        The following filename suffixes are accepted:
+        'meg', 'markers', 'eeg', 'ieeg', 'T1w',
+        'participants', 'scans', 'electrodes', 'coordsystem',
+        'channels', 'events', 'headshape', 'digitizer',
+        'beh', 'physio', 'stim'
+    extensions : str | array-like of str | None
+        The extension of the filename. E.g., ``'.json'``.
+        By default, uses the ones accepted by :func:`mne_bids.read_raw_bids`.
+    datatypes : str | array-like of str | None
+        The BIDS data type, e.g., ``'anat'``, ``'func'``, ``'eeg'``, ``'meg'``,
+        ``'ieeg'``.
+    check : bool
+        If ``True``, only returns paths that conform to BIDS. If ``False``
+        (default), the ``.check`` attribute of the returned
+        :class:`mne_bids.BIDSPath` object will be set to ``True`` for paths that
+        do conform to BIDS, and to ``False`` for those that don't.
+    preload : bool
+        If True, preload the data. Defaults to False.
+    n_jobs : int
+        Number of jobs to run in parallel. Defaults to 1.
+
+    References
+    ----------
+    .. footbibliography::
+    """
+
+    def __init__(self, root=None, *args, **kwargs):
+        # Download dataset if not present
+        if root is None:
+            path_root = fetch_dataset(
+                dataset_params=CHB_MIT_dataset_params,
+                path=None,
+                processor="unzip",
+                force_update=False,
+            )
+            # First time we fetch the dataset, we need to move the files to the
+            # correct directory.
+            path_root = _correct_dataset_path(
+                path_root, CHB_MIT_archive_name, "BIDS_CHB-MIT"
+            )
+        else:
+            # Validate that the provided root is a valid BIDS dataset
+            if not Path(f"{root}/participants.tsv").exists():
+                raise ValueError(
+                    f"The provided root directory {root} does not contain a valid "
+                    "BIDS dataset (missing participants.tsv). Please ensure the "
+                    "root points directly to the BIDS dataset directory."
+                )
+            path_root = root
+
+        kwargs["root"] = path_root
+
+        super().__init__(
+            *args,
+            extensions=".edf",
+            check=False,
+            **kwargs,
+        )
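A minimal usage sketch: with root=None the first call downloads and unzips the Zenodo archive via fetch_dataset, and filtering arguments such as subjects are forwarded to BIDSDataset (the "01" label follows the sub-01 folder naming above; the .description table is assumed from the usual braindecode concat-dataset interface):

    from braindecode.datasets.chb_mit import CHBMIT

    # Download (if needed) and load the recordings of subject sub-01.
    ds = CHBMIT(subjects="01")
    print(ds.description)
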