braindecode 1.3.0.dev177069446__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (124)
  1. braindecode/__init__.py +9 -0
  2. braindecode/augmentation/__init__.py +52 -0
  3. braindecode/augmentation/base.py +225 -0
  4. braindecode/augmentation/functional.py +1300 -0
  5. braindecode/augmentation/transforms.py +1356 -0
  6. braindecode/classifier.py +258 -0
  7. braindecode/datasets/__init__.py +44 -0
  8. braindecode/datasets/base.py +823 -0
  9. braindecode/datasets/bbci.py +693 -0
  10. braindecode/datasets/bcicomp.py +193 -0
  11. braindecode/datasets/bids/__init__.py +54 -0
  12. braindecode/datasets/bids/datasets.py +239 -0
  13. braindecode/datasets/bids/format.py +717 -0
  14. braindecode/datasets/bids/hub.py +987 -0
  15. braindecode/datasets/bids/hub_format.py +717 -0
  16. braindecode/datasets/bids/hub_io.py +197 -0
  17. braindecode/datasets/bids/hub_validation.py +114 -0
  18. braindecode/datasets/bids/iterable.py +220 -0
  19. braindecode/datasets/chb_mit.py +163 -0
  20. braindecode/datasets/mne.py +170 -0
  21. braindecode/datasets/moabb.py +219 -0
  22. braindecode/datasets/nmt.py +313 -0
  23. braindecode/datasets/registry.py +120 -0
  24. braindecode/datasets/siena.py +162 -0
  25. braindecode/datasets/sleep_physio_challe_18.py +411 -0
  26. braindecode/datasets/sleep_physionet.py +125 -0
  27. braindecode/datasets/tuh.py +591 -0
  28. braindecode/datasets/utils.py +67 -0
  29. braindecode/datasets/xy.py +96 -0
  30. braindecode/datautil/__init__.py +62 -0
  31. braindecode/datautil/channel_utils.py +114 -0
  32. braindecode/datautil/hub_formats.py +180 -0
  33. braindecode/datautil/serialization.py +359 -0
  34. braindecode/datautil/util.py +154 -0
  35. braindecode/eegneuralnet.py +372 -0
  36. braindecode/functional/__init__.py +22 -0
  37. braindecode/functional/functions.py +251 -0
  38. braindecode/functional/initialization.py +47 -0
  39. braindecode/models/__init__.py +117 -0
  40. braindecode/models/atcnet.py +830 -0
  41. braindecode/models/attentionbasenet.py +727 -0
  42. braindecode/models/attn_sleep.py +549 -0
  43. braindecode/models/base.py +574 -0
  44. braindecode/models/bendr.py +493 -0
  45. braindecode/models/biot.py +537 -0
  46. braindecode/models/brainmodule.py +845 -0
  47. braindecode/models/config.py +233 -0
  48. braindecode/models/contrawr.py +319 -0
  49. braindecode/models/ctnet.py +541 -0
  50. braindecode/models/deep4.py +376 -0
  51. braindecode/models/deepsleepnet.py +417 -0
  52. braindecode/models/eegconformer.py +475 -0
  53. braindecode/models/eeginception_erp.py +379 -0
  54. braindecode/models/eeginception_mi.py +379 -0
  55. braindecode/models/eegitnet.py +302 -0
  56. braindecode/models/eegminer.py +256 -0
  57. braindecode/models/eegnet.py +359 -0
  58. braindecode/models/eegnex.py +354 -0
  59. braindecode/models/eegsimpleconv.py +201 -0
  60. braindecode/models/eegsym.py +917 -0
  61. braindecode/models/eegtcnet.py +337 -0
  62. braindecode/models/fbcnet.py +225 -0
  63. braindecode/models/fblightconvnet.py +315 -0
  64. braindecode/models/fbmsnet.py +338 -0
  65. braindecode/models/hybrid.py +126 -0
  66. braindecode/models/ifnet.py +443 -0
  67. braindecode/models/labram.py +1316 -0
  68. braindecode/models/luna.py +891 -0
  69. braindecode/models/medformer.py +760 -0
  70. braindecode/models/msvtnet.py +377 -0
  71. braindecode/models/patchedtransformer.py +640 -0
  72. braindecode/models/reve.py +843 -0
  73. braindecode/models/sccnet.py +280 -0
  74. braindecode/models/shallow_fbcsp.py +212 -0
  75. braindecode/models/signal_jepa.py +1122 -0
  76. braindecode/models/sinc_shallow.py +339 -0
  77. braindecode/models/sleep_stager_blanco_2020.py +169 -0
  78. braindecode/models/sleep_stager_chambon_2018.py +159 -0
  79. braindecode/models/sparcnet.py +426 -0
  80. braindecode/models/sstdpn.py +869 -0
  81. braindecode/models/summary.csv +47 -0
  82. braindecode/models/syncnet.py +234 -0
  83. braindecode/models/tcn.py +275 -0
  84. braindecode/models/tidnet.py +397 -0
  85. braindecode/models/tsinception.py +295 -0
  86. braindecode/models/usleep.py +439 -0
  87. braindecode/models/util.py +369 -0
  88. braindecode/modules/__init__.py +92 -0
  89. braindecode/modules/activation.py +86 -0
  90. braindecode/modules/attention.py +883 -0
  91. braindecode/modules/blocks.py +160 -0
  92. braindecode/modules/convolution.py +330 -0
  93. braindecode/modules/filter.py +654 -0
  94. braindecode/modules/layers.py +216 -0
  95. braindecode/modules/linear.py +70 -0
  96. braindecode/modules/parametrization.py +38 -0
  97. braindecode/modules/stats.py +87 -0
  98. braindecode/modules/util.py +85 -0
  99. braindecode/modules/wrapper.py +90 -0
  100. braindecode/preprocessing/__init__.py +271 -0
  101. braindecode/preprocessing/eegprep_preprocess.py +1317 -0
  102. braindecode/preprocessing/mne_preprocess.py +240 -0
  103. braindecode/preprocessing/preprocess.py +579 -0
  104. braindecode/preprocessing/util.py +177 -0
  105. braindecode/preprocessing/windowers.py +1037 -0
  106. braindecode/regressor.py +234 -0
  107. braindecode/samplers/__init__.py +18 -0
  108. braindecode/samplers/base.py +399 -0
  109. braindecode/samplers/ssl.py +263 -0
  110. braindecode/training/__init__.py +23 -0
  111. braindecode/training/callbacks.py +23 -0
  112. braindecode/training/losses.py +105 -0
  113. braindecode/training/scoring.py +477 -0
  114. braindecode/util.py +419 -0
  115. braindecode/version.py +1 -0
  116. braindecode/visualization/__init__.py +8 -0
  117. braindecode/visualization/confusion_matrices.py +289 -0
  118. braindecode/visualization/gradients.py +62 -0
  119. braindecode-1.3.0.dev177069446.dist-info/METADATA +230 -0
  120. braindecode-1.3.0.dev177069446.dist-info/RECORD +124 -0
  121. braindecode-1.3.0.dev177069446.dist-info/WHEEL +5 -0
  122. braindecode-1.3.0.dev177069446.dist-info/licenses/LICENSE.txt +31 -0
  123. braindecode-1.3.0.dev177069446.dist-info/licenses/NOTICE.txt +20 -0
  124. braindecode-1.3.0.dev177069446.dist-info/top_level.txt +1 -0
@@ -0,0 +1,197 @@
+ # mypy: ignore-errors
+ """
+ Low-level Zarr I/O helpers for Hub integration.
+
+ These functions keep the Zarr serialization details isolated from hub.py.
+ """
+
+ from __future__ import annotations
+
+ import json
+ import warnings
+ from pathlib import Path
+
+ import numpy as np
+ import pandas as pd
+ from mne.utils import _soft_import
+
+ zarr = _soft_import("zarr", purpose="Hugging Face integration", strict=False)
+
+
+ def _sanitize_for_json(obj):
+     """Recursively replace NaN/Inf with None to produce valid JSON."""
+     if isinstance(obj, float):
+         if np.isnan(obj) or np.isinf(obj):
+             return None
+         return obj
+     if isinstance(obj, dict):
+         return {k: _sanitize_for_json(v) for k, v in obj.items()}
+     if isinstance(obj, list):
+         return [_sanitize_for_json(v) for v in obj]
+     if isinstance(obj, np.ndarray):
+         return _sanitize_for_json(obj.tolist())
+     return obj
+
+
+ def _restore_nan_from_json(obj):
+     """Restore NaN values from None in JSON-loaded data."""
+     if isinstance(obj, dict):
+         return {k: _restore_nan_from_json(v) for k, v in obj.items()}
+     if isinstance(obj, list):
+         # Flat numeric lists: map None back to NaN.
+         if len(obj) > 0 and all(isinstance(x, (int, float, type(None))) for x in obj):
+             return [np.nan if x is None else x for x in obj]
+         return [_restore_nan_from_json(v) for v in obj]
+     return obj
+
+
+ def _save_windows_to_zarr(
+     grp, data, metadata, description, info, compressor, target_name
+ ):
+     """Save windowed data to a Zarr group (low-level function)."""
+     data_array = data.astype(np.float32)
+     compressors_list = [compressor] if compressor is not None else None
+
+     # One chunk per window keeps random access to single windows cheap.
+     grp.create_array(
+         "data",
+         data=data_array,
+         chunks=(1, data_array.shape[1], data_array.shape[2]),
+         compressors=compressors_list,
+     )
+
+     # Write the metadata DataFrame as a TSV file inside the store.
+     store_path = getattr(grp.store, "path", getattr(grp.store, "root", None))
+     metadata_path = Path(store_path) / grp.path / "metadata.tsv"
+     metadata.to_csv(metadata_path, sep="\t", index=True)
+
+     grp.attrs["description"] = json.loads(description.to_json(date_format="iso"))
+     grp.attrs["info"] = _sanitize_for_json(info)
+
+     if target_name is not None:
+         grp.attrs["target_name"] = target_name
+
+
+ def _save_eegwindows_to_zarr(
+     grp, raw, metadata, description, info, targets_from, last_target_only, compressor
+ ):
+     """Save continuous raw EEG data to a Zarr group (low-level function)."""
+     continuous_data = raw.get_data()
+     continuous_float = continuous_data.astype(np.float32)
+     compressors_list = [compressor] if compressor is not None else None
+
+     # All channels in one chunk, at most 10000 samples per chunk.
+     grp.create_array(
+         "data",
+         data=continuous_float,
+         chunks=(continuous_float.shape[0], min(10000, continuous_float.shape[1])),
+         compressors=compressors_list,
+     )
+
+     store_path = getattr(grp.store, "path", getattr(grp.store, "root", None))
+     metadata_path = Path(store_path) / grp.path / "metadata.tsv"
+     metadata.to_csv(metadata_path, sep="\t", index=True)
+
+     grp.attrs["description"] = json.loads(description.to_json(date_format="iso"))
+     grp.attrs["info"] = _sanitize_for_json(info)
+     grp.attrs["targets_from"] = targets_from
+     grp.attrs["last_target_only"] = last_target_only
+
+
+ def _load_windows_from_zarr(grp, preload):
+     """Load windowed data from a Zarr group (low-level function)."""
+     store_path = getattr(grp.store, "path", getattr(grp.store, "root", None))
+     metadata_path = Path(store_path) / grp.path / "metadata.tsv"
+     metadata = pd.read_csv(metadata_path, sep="\t", index_col=0)
+
+     description = pd.Series(grp.attrs["description"])
+     info_dict = _restore_nan_from_json(grp.attrs["info"])
+
+     # Lazy loading is not implemented yet, so data is always read eagerly.
+     data = grp["data"][:]
+     if not preload:
+         warnings.warn(
+             "Lazy loading from Zarr not fully implemented yet. "
+             "Loading all data into memory.",
+             UserWarning,
+         )
+
+     target_name = grp.attrs.get("target_name", None)
+
+     return data, metadata, description, info_dict, target_name
+
+
+ def _load_eegwindows_from_zarr(grp, preload):
+     """Load continuous raw EEG data from a Zarr group (low-level function)."""
+     store_path = getattr(grp.store, "path", getattr(grp.store, "root", None))
+     metadata_path = Path(store_path) / grp.path / "metadata.tsv"
+     metadata = pd.read_csv(metadata_path, sep="\t", index_col=0)
+
+     description = pd.Series(grp.attrs["description"])
+     info_dict = _restore_nan_from_json(grp.attrs["info"])
+
+     # Lazy loading is not implemented yet, so data is always read eagerly.
+     data = grp["data"][:]
+     if not preload:
+         warnings.warn(
+             "Lazy loading from Zarr not fully implemented yet. "
+             "Loading all data into memory.",
+             UserWarning,
+         )
+
+     targets_from = grp.attrs.get("targets_from", "metadata")
+     last_target_only = grp.attrs.get("last_target_only", True)
+
+     return data, metadata, description, info_dict, targets_from, last_target_only
+
+
+ def _save_raw_to_zarr(grp, raw, description, info, target_name, compressor):
+     """Save RawDataset continuous raw data to a Zarr group (low-level function)."""
+     continuous_data = raw.get_data()
+     continuous_float = continuous_data.astype(np.float32)
+     compressors_list = [compressor] if compressor is not None else None
+
+     grp.create_array(
+         "data",
+         data=continuous_float,
+         chunks=(continuous_float.shape[0], min(10000, continuous_float.shape[1])),
+         compressors=compressors_list,
+     )
+
+     grp.attrs["description"] = json.loads(description.to_json(date_format="iso"))
+     grp.attrs["info"] = _sanitize_for_json(info)
+
+     if target_name is not None:
+         grp.attrs["target_name"] = target_name
+
+
+ def _load_raw_from_zarr(grp, preload):
+     """Load RawDataset continuous raw data from a Zarr group (low-level function)."""
+     description = pd.Series(grp.attrs["description"])
+     info_dict = _restore_nan_from_json(grp.attrs["info"])
+
+     # Lazy loading is not implemented yet, so data is always read eagerly.
+     data = grp["data"][:]
+     if not preload:
+         warnings.warn(
+             "Lazy loading from Zarr not fully implemented yet. "
+             "Loading all data into memory.",
+             UserWarning,
+         )
+
+     target_name = grp.attrs.get("target_name", None)
+
+     return data, description, info_dict, target_name
+
+
+ def _create_compressor(compression, compression_level):
+     """Create a Zarr v3 compressor codec."""
+     if zarr is False:
+         raise ImportError(
+             "Zarr is not installed. Install with: pip install braindecode[hub]"
+         )
+
+     # Unknown names (and None) silently disable compression.
+     if compression not in ("blosc", "zstd", "gzip"):
+         return None
+
+     # "blosc" requests are mapped to the zstd codec here.
+     name = "zstd" if compression == "blosc" else compression
+     return {"name": name, "configuration": {"level": compression_level}}
@@ -0,0 +1,114 @@
+ # mypy: ignore-errors
+ """
+ Shared validation utilities for Hub format operations.
+
+ This module provides validation functions used by hub.py to avoid code duplication.
+ """
+
+ # Authors: Kuntal Kokate
+ #
+ # License: BSD (3-clause)
+
+ from typing import Any, List, Tuple
+
+ from ..registry import get_dataset_type
+
+
+ def validate_dataset_uniformity(
+     datasets: List[Any],
+ ) -> Tuple[str, List[str], float]:
+     """
+     Validate that all datasets have a uniform type, channel set, and sampling frequency.
+
+     Parameters
+     ----------
+     datasets : list
+         List of dataset objects to validate.
+
+     Returns
+     -------
+     dataset_type : str
+         The validated dataset type (WindowsDataset, EEGWindowsDataset, or RawDataset).
+     first_ch_names : list of str
+         Channel names from the first dataset.
+     first_sfreq : float
+         Sampling frequency from the first dataset.
+
+     Raises
+     ------
+     ValueError
+         If datasets have mixed types, inconsistent channels, or inconsistent
+         sampling frequencies.
+     TypeError
+         If the dataset type is not supported.
+     """
+     if not datasets:
+         raise ValueError("No datasets provided for validation.")
+
+     first_ds = datasets[0]
+     dataset_type = get_dataset_type(first_ds)
+
+     # Get reference channel names and sampling frequency from the first dataset
+     first_ch_names, first_sfreq = _get_ch_names_and_sfreq(first_ds, dataset_type)
+
+     # Validate that all datasets have uniform properties
+     for i, ds in enumerate(datasets):
+         ds_type = get_dataset_type(ds)
+         if ds_type != dataset_type:
+             raise ValueError(
+                 f"Mixed dataset types in concat: dataset 0 is {dataset_type} "
+                 f"but dataset {i} is {ds_type}"
+             )
+
+         ch_names, sfreq = _get_ch_names_and_sfreq(ds, dataset_type)
+
+         if ch_names != first_ch_names:
+             raise ValueError(
+                 f"Inconsistent channel names: dataset 0 has {first_ch_names} "
+                 f"but dataset {i} has {ch_names}"
+             )
+
+         if sfreq != first_sfreq:
+             _raise_sfreq_error(first_sfreq, sfreq, i)
+
+     return dataset_type, first_ch_names, first_sfreq
+
+
+ def _get_ch_names_and_sfreq(ds: Any, dataset_type: str) -> Tuple[List[str], float]:
+     """Return (ch_names, sfreq) for supported dataset types."""
+     if dataset_type == "WindowsDataset":
+         obj = ds.windows
+     elif dataset_type in ("EEGWindowsDataset", "RawDataset"):
+         obj = ds.raw
+     else:
+         raise TypeError(f"Unsupported dataset type: {dataset_type}")
+
+     return obj.ch_names, obj.info["sfreq"]
+
+
+ def _raise_sfreq_error(expected: float, actual: float, idx: int):
+     """
+     Raise a standardized sampling frequency error.
+
+     Parameters
+     ----------
+     expected : float
+         Expected sampling frequency, taken from dataset 0.
+     actual : float
+         Actual sampling frequency of the current dataset.
+     idx : int
+         Index of the dataset with the inconsistent sampling frequency.
+
+     Raises
+     ------
+     ValueError
+         Always raised, with a standardized error message.
+     """
+     raise ValueError(
+         f"Inconsistent sampling frequencies: dataset 0 has {expected} Hz "
+         f"but dataset {idx} has {actual} Hz. "
+         f"Please resample all datasets to a common frequency before saving. "
+         f"Use braindecode.preprocessing.preprocess("
+         f"[Preprocessor(Resample(sfreq={expected}))], concat_ds) "
+         f"to resample your datasets."
+     )
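
For orientation, a short sketch of the calling pattern, assuming `concat_ds` is an existing BaseConcatDataset whose `.datasets` attribute holds the individual recordings (as hub.py would have before saving):

    from braindecode.datasets.bids.hub_validation import validate_dataset_uniformity

    # Raises ValueError/TypeError on mixed types, channel sets, or sampling
    # rates; on success, the returned triple describes the whole collection.
    dataset_type, ch_names, sfreq = validate_dataset_uniformity(list(concat_ds.datasets))
    print(dataset_type, len(ch_names), sfreq)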
@@ -0,0 +1,220 @@
+ from __future__ import annotations
+
+ import random
+ from pathlib import Path
+ from typing import Callable, Sequence
+
+ import mne_bids
+ from torch.utils.data import IterableDataset, get_worker_info
+
+
+ class BIDSIterableDataset(IterableDataset):
+     """Iterable dataset for loading BIDS recordings.
+
+     .. warning::
+         This class is experimental and may change in the future.
+
+     .. warning::
+         This dataset is not consistent with the Braindecode API.
+
+     This class accepts the same parameters as :func:`mne_bids.find_matching_paths`,
+     which is used to find the files to load. Only the default of the
+     ``extensions`` parameter was changed.
+
+     More information on BIDS (Brain Imaging Data Structure)
+     can be found at https://bids.neuroimaging.io
+
+     Examples
+     --------
+     >>> import mne_bids
+     >>> from braindecode.datasets import BaseConcatDataset, RawDataset, RecordDataset
+     >>> from braindecode.datasets.bids import BIDSIterableDataset
+     >>> from braindecode.preprocessing import create_fixed_length_windows
+     >>>
+     >>> def my_reader_fn(path):
+     ...     raw = mne_bids.read_raw_bids(path)
+     ...     ds: RecordDataset = RawDataset(raw, description={"path": path.fpath})
+     ...     windows_ds = create_fixed_length_windows(
+     ...         BaseConcatDataset([ds]),
+     ...         window_size_samples=400,
+     ...         window_stride_samples=200,
+     ...     )
+     ...     return windows_ds
+     >>>
+     >>> dataset = BIDSIterableDataset(
+     ...     reader_fn=my_reader_fn,
+     ...     root="root/of/my/bids/dataset/",
+     ... )
+
+     Parameters
+     ----------
+     reader_fn : Callable[[mne_bids.BIDSPath], Sequence]
+         A function that takes a BIDSPath and returns a dataset (e.g., a
+         RecordDataset or a BaseConcatDataset of RecordDatasets).
+     pool_size : int
+         The number of recordings kept open at once; each yielded example is
+         drawn at random from this pool.
+     bids_paths : list[mne_bids.BIDSPath] | None
+         A list of BIDSPaths to load. If None, the paths found by
+         :func:`mne_bids.find_matching_paths` with the arguments below are used.
+     root : pathlib.Path | str
+         The root of the BIDS path.
+     subjects : str | array-like of str | None
+         The subject ID. Corresponds to "sub".
+     sessions : str | array-like of str | None
+         The acquisition session. Corresponds to "ses".
+     tasks : str | array-like of str | None
+         The experimental task. Corresponds to "task".
+     acquisitions : str | array-like of str | None
+         The acquisition parameters. Corresponds to "acq".
+     runs : str | array-like of str | None
+         The run number. Corresponds to "run".
+     processings : str | array-like of str | None
+         The processing label. Corresponds to "proc".
+     recordings : str | array-like of str | None
+         The recording name. Corresponds to "rec".
+     spaces : str | array-like of str | None
+         The coordinate space for anatomical and sensor location
+         files (e.g., ``*_electrodes.tsv``, ``*_markers.mrk``).
+         Corresponds to "space".
+         Note that valid values for ``space`` must come from a list
+         of BIDS keywords as described in the BIDS specification.
+     splits : str | array-like of str | None
+         The split of the continuous recording file for ``.fif`` data.
+         Corresponds to "split".
+     descriptions : str | array-like of str | None
+         This corresponds to the BIDS entity ``desc``. It is used to provide
+         additional information for derivative data, e.g., preprocessed data
+         may be assigned ``description='cleaned'``.
+     suffixes : str | array-like of str | None
+         The filename suffix. This is the entity after the
+         last ``_`` before the extension. E.g., ``'channels'``.
+         The following filename suffixes are accepted:
+         'meg', 'markers', 'eeg', 'ieeg', 'T1w',
+         'participants', 'scans', 'electrodes', 'coordsystem',
+         'channels', 'events', 'headshape', 'digitizer',
+         'beh', 'physio', 'stim'
+     extensions : str | array-like of str | None
+         The extension of the filename. E.g., ``'.json'``.
+         By default, uses the ones accepted by :func:`mne_bids.read_raw_bids`.
+     datatypes : str | array-like of str | None
+         The BIDS data type, e.g., ``'anat'``, ``'func'``, ``'eeg'``, ``'meg'``,
+         ``'ieeg'``.
+     check : bool
+         If ``True``, only returns paths that conform to BIDS. If ``False``
+         (default), the ``.check`` attribute of the returned
+         :class:`mne_bids.BIDSPath` object will be set to ``True`` for paths that
+         do conform to BIDS, and to ``False`` for those that don't.
+     """
+
+     def __init__(
+         self,
+         reader_fn: Callable[[mne_bids.BIDSPath], Sequence],
+         pool_size: int = 4,
+         bids_paths: list[mne_bids.BIDSPath] | None = None,
+         root: Path | str | None = None,
+         subjects: str | list[str] | None = None,
+         sessions: str | list[str] | None = None,
+         tasks: str | list[str] | None = None,
+         acquisitions: str | list[str] | None = None,
+         runs: str | list[str] | None = None,
+         processings: str | list[str] | None = None,
+         recordings: str | list[str] | None = None,
+         spaces: str | list[str] | None = None,
+         splits: str | list[str] | None = None,
+         descriptions: str | list[str] | None = None,
+         suffixes: str | list[str] | None = None,
+         extensions: str | list[str] | None = [
+             ".con",
+             ".sqd",
+             ".pdf",
+             ".fif",
+             ".ds",
+             ".vhdr",
+             ".set",
+             ".edf",
+             ".bdf",
+             ".EDF",
+             ".snirf",
+             ".cdt",
+             ".mef",
+             ".nwb",
+         ],
+         datatypes: str | list[str] | None = None,
+         check: bool = False,
+     ):
+         if bids_paths is None:
+             bids_paths = mne_bids.find_matching_paths(
+                 root=root,
+                 subjects=subjects,
+                 sessions=sessions,
+                 tasks=tasks,
+                 acquisitions=acquisitions,
+                 runs=runs,
+                 processings=processings,
+                 recordings=recordings,
+                 spaces=spaces,
+                 splits=splits,
+                 descriptions=descriptions,
+                 suffixes=suffixes,
+                 extensions=extensions,
+                 datatypes=datatypes,
+                 check=check,
+                 ignore_json=True,
+             )
+             # Filter out _epo.fif files:
+             bids_paths = [
+                 bids_path
+                 for bids_path in bids_paths
+                 if not (bids_path.suffix == "epo" and bids_path.extension == ".fif")
+             ]
+         self.bids_paths = bids_paths
+         self.reader_fn = reader_fn
+         self.pool_size = pool_size
+
+     def __add__(self, other):
+         assert isinstance(other, BIDSIterableDataset)
+         return BIDSIterableDataset(
+             reader_fn=self.reader_fn,
+             bids_paths=self.bids_paths + other.bids_paths,
+             pool_size=self.pool_size,
+         )
+
+     def __iadd__(self, other):
+         assert isinstance(other, BIDSIterableDataset)
+         self.bids_paths += other.bids_paths
+         return self
+
+     def __iter__(self):
+         worker_info = get_worker_info()
+         if worker_info is None:  # single-process data loading, use all paths
+             bids_paths = self.bids_paths
+         else:  # in a worker process: split the workload across workers
+             bids_paths = self.bids_paths[worker_info.id :: worker_info.num_workers]
+
+         pool = []
+         end = False
+         paths_it = iter(random.sample(bids_paths, k=len(bids_paths)))
+         while not (end and len(pool) == 0):
+             # Top up the pool with freshly read recordings.
+             while not end and len(pool) < self.pool_size:
+                 try:
+                     bids_path = next(paths_it)
+                     ds = self.reader_fn(bids_path)
+                     if ds is None:
+                         print(f"Skipping {bids_path}: reader_fn returned None.")
+                         continue
+                     idx = iter(random.sample(range(len(ds)), k=len(ds)))
+                     pool.append((ds, idx))
+                 except StopIteration:
+                     end = True
+             if not pool:  # no recording could be read at all
+                 break
+             # Yield one example from a randomly chosen recording in the pool.
+             i_pool = random.randint(0, len(pool) - 1)
+             ds, idx = pool[i_pool]
+             try:
+                 i_ds = next(idx)
+                 yield ds[i_ds]
+             except StopIteration:
+                 pool.pop(i_pool)
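
Because __iter__ shards self.bids_paths by worker_info.id, the class composes directly with a multi-worker DataLoader. A sketch reusing `dataset` from the docstring example (what a batch contains depends entirely on what reader_fn returns):

    from torch.utils.data import DataLoader

    # Each worker reads every num_workers-th recording, so no recording is
    # opened twice; within a worker, windows are interleaved at random from
    # up to pool_size open recordings, approximating a global shuffle
    # without loading everything into memory.
    loader = DataLoader(dataset, batch_size=64, num_workers=4)
    for batch in loader:
        ...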
@@ -0,0 +1,163 @@
+ """
+ This dataset is a BIDS-compatible version of the CHB-MIT Scalp EEG Database.
+
+ It reorganizes the file structure to comply with the BIDS specification. To this effect:
+
+ - The data from subject chb21 was moved to sub-01/ses-02.
+ - Metadata was organized according to BIDS.
+ - Data in the EEG edf files was modified to keep only the 18 channels from a
+   double banana bipolar montage.
+ - Annotations were formatted as BIDS-score compatible `tsv` files.
+ """
+
+ # Authors: Dan, Jonathan
+ #          Shoeb, Ali (Data collector)
+ #          Bruno Aristimunha <b.aristimunha@gmail.com>
+ #
+ # License: BSD (3-clause)
+ from __future__ import annotations
+
+ from pathlib import Path
+
+ from mne.datasets import fetch_dataset
+
+ from braindecode.datasets import BIDSDataset
+ from braindecode.datasets.utils import _correct_dataset_path
+
+ CHB_MIT_URL = "https://zenodo.org/records/10259996/files/BIDS_CHB-MIT.zip"
+ CHB_MIT_archive_name = "chb_mit_bids.zip"
+ CHB_MIT_folder_name = "CHB-MIT-BIDS-eeg-dataset"
+ CHB_MIT_dataset_name = "CHB-MIT-EEG-Corpus"
+
+ CHB_MIT_dataset_params = {
+     "dataset_name": CHB_MIT_dataset_name,
+     "url": CHB_MIT_URL,
+     "archive_name": CHB_MIT_archive_name,
+     "folder_name": CHB_MIT_folder_name,
+     "hash": "078f4e110e40d10fef1a38a892571ad24666c488e8118a01002c9224909256ed",  # sha256
+     "config_key": CHB_MIT_dataset_name,
+ }
+
+
+ class CHBMIT(BIDSDataset):
+     """The Children's Hospital Boston EEG Dataset.
+
+     This database, collected at the Children's Hospital Boston, consists of EEG recordings
+     from pediatric subjects with intractable seizures. Subjects were monitored for up to
+     several days following withdrawal of anti-seizure medication in order to characterize
+     their seizures and assess their candidacy for surgical intervention.
+
+     **Description of the contents of the dataset:**
+
+     Each folder (sub-01, sub-02, etc.) contains between 9 and 42 continuous .edf
+     files from a single subject. Hardware limitations resulted in gaps between
+     consecutively-numbered .edf files, during which the signals were not recorded;
+     in most cases, the gaps are 10 seconds or less, but occasionally there are much
+     longer gaps. In order to protect the privacy of the subjects, all protected health
+     information (PHI) in the original .edf files has been replaced with surrogate information
+     in the files provided here. Dates in the original .edf files have been replaced by
+     surrogate dates, but the time relationships between the individual files belonging
+     to each case have been preserved. In most cases, the .edf files contain exactly one
+     hour of digitized EEG signals, although those belonging to case sub-10 are two hours
+     long, and those belonging to cases sub-04, sub-06, sub-07, sub-09, and sub-23 are
+     four hours long; occasionally, files in which seizures are recorded are shorter.
+
+     The EEG is recorded at 256 Hz with 16-bit resolution. The recordings are
+     referenced in a double banana bipolar montage with 18 channels from the 10-20 electrode system.
+
+     This BIDS-compatible version of the dataset was published by Jonathan Dan :footcite:`Dan2025`
+     and is based on the original CHB-MIT EEG Database :footcite:`Guttag2010`, :footcite:`Shoeb2009`.
+
+     .. versionadded:: 1.3
+
+     Parameters
+     ----------
+     root : pathlib.Path | str
+         The root of the BIDS path.
+     subjects : str | array-like of str | None
+         The subject ID. Corresponds to "sub".
+     sessions : str | array-like of str | None
+         The acquisition session. Corresponds to "ses".
+     tasks : str | array-like of str | None
+         The experimental task. Corresponds to "task".
+     acquisitions : str | array-like of str | None
+         The acquisition parameters. Corresponds to "acq".
+     runs : str | array-like of str | None
+         The run number. Corresponds to "run".
+     processings : str | array-like of str | None
+         The processing label. Corresponds to "proc".
+     recordings : str | array-like of str | None
+         The recording name. Corresponds to "rec".
+     spaces : str | array-like of str | None
+         The coordinate space for anatomical and sensor location
+         files (e.g., ``*_electrodes.tsv``, ``*_markers.mrk``).
+         Corresponds to "space".
+         Note that valid values for ``space`` must come from a list
+         of BIDS keywords as described in the BIDS specification.
+     splits : str | array-like of str | None
+         The split of the continuous recording file for ``.fif`` data.
+         Corresponds to "split".
+     descriptions : str | array-like of str | None
+         This corresponds to the BIDS entity ``desc``. It is used to provide
+         additional information for derivative data, e.g., preprocessed data
+         may be assigned ``description='cleaned'``.
+     suffixes : str | array-like of str | None
+         The filename suffix. This is the entity after the
+         last ``_`` before the extension. E.g., ``'channels'``.
+         The following filename suffixes are accepted:
+         'meg', 'markers', 'eeg', 'ieeg', 'T1w',
+         'participants', 'scans', 'electrodes', 'coordsystem',
+         'channels', 'events', 'headshape', 'digitizer',
+         'beh', 'physio', 'stim'
+     extensions : str | array-like of str | None
+         The extension of the filename. E.g., ``'.json'``.
+         By default, uses the ones accepted by :func:`mne_bids.read_raw_bids`.
+     datatypes : str | array-like of str | None
+         The BIDS data type, e.g., ``'anat'``, ``'func'``, ``'eeg'``, ``'meg'``,
+         ``'ieeg'``.
+     check : bool
+         If ``True``, only returns paths that conform to BIDS. If ``False``
+         (default), the ``.check`` attribute of the returned
+         :class:`mne_bids.BIDSPath` object will be set to ``True`` for paths that
+         do conform to BIDS, and to ``False`` for those that don't.
+     preload : bool
+         If True, preload the data. Defaults to False.
+     n_jobs : int
+         Number of jobs to run in parallel. Defaults to 1.
+
+     References
+     ----------
+     .. footbibliography::
+     """
+
+     def __init__(self, root=None, *args, **kwargs):
+         # Download the dataset if it is not present
+         if root is None:
+             path_root = fetch_dataset(
+                 dataset_params=CHB_MIT_dataset_params,
+                 path=None,
+                 processor="unzip",
+                 force_update=False,
+             )
+             # The first time the dataset is fetched, the files need to be
+             # moved to the correct directory.
+             path_root = _correct_dataset_path(
+                 path_root, CHB_MIT_archive_name, "BIDS_CHB-MIT"
+             )
+         else:
+             # Validate that the provided root is a valid BIDS dataset
+             if not (Path(root) / "participants.tsv").exists():
+                 raise ValueError(
+                     f"The provided root directory {root} does not contain a valid "
+                     "BIDS dataset (missing participants.tsv). Please ensure the "
+                     "root points directly to the BIDS dataset directory."
+                 )
+             path_root = root
+
+         kwargs["root"] = path_root
+
+         super().__init__(
+             *args,
+             extensions=".edf",
+             check=False,
+             **kwargs,
+         )
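
A minimal usage sketch; the subject choice is illustrative, and it assumes BIDSDataset exposes the usual BaseConcatDataset interface:

    from braindecode.datasets.chb_mit import CHBMIT

    # With root=None, the BIDS archive is fetched from Zenodo on first use
    # (a large download) and cached; pass root=... to reuse a local copy.
    ds = CHBMIT(subjects="01")
    print(ds.description)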