braindecode 1.3.0.dev177069446__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (124)
  1. braindecode/__init__.py +9 -0
  2. braindecode/augmentation/__init__.py +52 -0
  3. braindecode/augmentation/base.py +225 -0
  4. braindecode/augmentation/functional.py +1300 -0
  5. braindecode/augmentation/transforms.py +1356 -0
  6. braindecode/classifier.py +258 -0
  7. braindecode/datasets/__init__.py +44 -0
  8. braindecode/datasets/base.py +823 -0
  9. braindecode/datasets/bbci.py +693 -0
  10. braindecode/datasets/bcicomp.py +193 -0
  11. braindecode/datasets/bids/__init__.py +54 -0
  12. braindecode/datasets/bids/datasets.py +239 -0
  13. braindecode/datasets/bids/format.py +717 -0
  14. braindecode/datasets/bids/hub.py +987 -0
  15. braindecode/datasets/bids/hub_format.py +717 -0
  16. braindecode/datasets/bids/hub_io.py +197 -0
  17. braindecode/datasets/bids/hub_validation.py +114 -0
  18. braindecode/datasets/bids/iterable.py +220 -0
  19. braindecode/datasets/chb_mit.py +163 -0
  20. braindecode/datasets/mne.py +170 -0
  21. braindecode/datasets/moabb.py +219 -0
  22. braindecode/datasets/nmt.py +313 -0
  23. braindecode/datasets/registry.py +120 -0
  24. braindecode/datasets/siena.py +162 -0
  25. braindecode/datasets/sleep_physio_challe_18.py +411 -0
  26. braindecode/datasets/sleep_physionet.py +125 -0
  27. braindecode/datasets/tuh.py +591 -0
  28. braindecode/datasets/utils.py +67 -0
  29. braindecode/datasets/xy.py +96 -0
  30. braindecode/datautil/__init__.py +62 -0
  31. braindecode/datautil/channel_utils.py +114 -0
  32. braindecode/datautil/hub_formats.py +180 -0
  33. braindecode/datautil/serialization.py +359 -0
  34. braindecode/datautil/util.py +154 -0
  35. braindecode/eegneuralnet.py +372 -0
  36. braindecode/functional/__init__.py +22 -0
  37. braindecode/functional/functions.py +251 -0
  38. braindecode/functional/initialization.py +47 -0
  39. braindecode/models/__init__.py +117 -0
  40. braindecode/models/atcnet.py +830 -0
  41. braindecode/models/attentionbasenet.py +727 -0
  42. braindecode/models/attn_sleep.py +549 -0
  43. braindecode/models/base.py +574 -0
  44. braindecode/models/bendr.py +493 -0
  45. braindecode/models/biot.py +537 -0
  46. braindecode/models/brainmodule.py +845 -0
  47. braindecode/models/config.py +233 -0
  48. braindecode/models/contrawr.py +319 -0
  49. braindecode/models/ctnet.py +541 -0
  50. braindecode/models/deep4.py +376 -0
  51. braindecode/models/deepsleepnet.py +417 -0
  52. braindecode/models/eegconformer.py +475 -0
  53. braindecode/models/eeginception_erp.py +379 -0
  54. braindecode/models/eeginception_mi.py +379 -0
  55. braindecode/models/eegitnet.py +302 -0
  56. braindecode/models/eegminer.py +256 -0
  57. braindecode/models/eegnet.py +359 -0
  58. braindecode/models/eegnex.py +354 -0
  59. braindecode/models/eegsimpleconv.py +201 -0
  60. braindecode/models/eegsym.py +917 -0
  61. braindecode/models/eegtcnet.py +337 -0
  62. braindecode/models/fbcnet.py +225 -0
  63. braindecode/models/fblightconvnet.py +315 -0
  64. braindecode/models/fbmsnet.py +338 -0
  65. braindecode/models/hybrid.py +126 -0
  66. braindecode/models/ifnet.py +443 -0
  67. braindecode/models/labram.py +1316 -0
  68. braindecode/models/luna.py +891 -0
  69. braindecode/models/medformer.py +760 -0
  70. braindecode/models/msvtnet.py +377 -0
  71. braindecode/models/patchedtransformer.py +640 -0
  72. braindecode/models/reve.py +843 -0
  73. braindecode/models/sccnet.py +280 -0
  74. braindecode/models/shallow_fbcsp.py +212 -0
  75. braindecode/models/signal_jepa.py +1122 -0
  76. braindecode/models/sinc_shallow.py +339 -0
  77. braindecode/models/sleep_stager_blanco_2020.py +169 -0
  78. braindecode/models/sleep_stager_chambon_2018.py +159 -0
  79. braindecode/models/sparcnet.py +426 -0
  80. braindecode/models/sstdpn.py +869 -0
  81. braindecode/models/summary.csv +47 -0
  82. braindecode/models/syncnet.py +234 -0
  83. braindecode/models/tcn.py +275 -0
  84. braindecode/models/tidnet.py +397 -0
  85. braindecode/models/tsinception.py +295 -0
  86. braindecode/models/usleep.py +439 -0
  87. braindecode/models/util.py +369 -0
  88. braindecode/modules/__init__.py +92 -0
  89. braindecode/modules/activation.py +86 -0
  90. braindecode/modules/attention.py +883 -0
  91. braindecode/modules/blocks.py +160 -0
  92. braindecode/modules/convolution.py +330 -0
  93. braindecode/modules/filter.py +654 -0
  94. braindecode/modules/layers.py +216 -0
  95. braindecode/modules/linear.py +70 -0
  96. braindecode/modules/parametrization.py +38 -0
  97. braindecode/modules/stats.py +87 -0
  98. braindecode/modules/util.py +85 -0
  99. braindecode/modules/wrapper.py +90 -0
  100. braindecode/preprocessing/__init__.py +271 -0
  101. braindecode/preprocessing/eegprep_preprocess.py +1317 -0
  102. braindecode/preprocessing/mne_preprocess.py +240 -0
  103. braindecode/preprocessing/preprocess.py +579 -0
  104. braindecode/preprocessing/util.py +177 -0
  105. braindecode/preprocessing/windowers.py +1037 -0
  106. braindecode/regressor.py +234 -0
  107. braindecode/samplers/__init__.py +18 -0
  108. braindecode/samplers/base.py +399 -0
  109. braindecode/samplers/ssl.py +263 -0
  110. braindecode/training/__init__.py +23 -0
  111. braindecode/training/callbacks.py +23 -0
  112. braindecode/training/losses.py +105 -0
  113. braindecode/training/scoring.py +477 -0
  114. braindecode/util.py +419 -0
  115. braindecode/version.py +1 -0
  116. braindecode/visualization/__init__.py +8 -0
  117. braindecode/visualization/confusion_matrices.py +289 -0
  118. braindecode/visualization/gradients.py +62 -0
  119. braindecode-1.3.0.dev177069446.dist-info/METADATA +230 -0
  120. braindecode-1.3.0.dev177069446.dist-info/RECORD +124 -0
  121. braindecode-1.3.0.dev177069446.dist-info/WHEEL +5 -0
  122. braindecode-1.3.0.dev177069446.dist-info/licenses/LICENSE.txt +31 -0
  123. braindecode-1.3.0.dev177069446.dist-info/licenses/NOTICE.txt +20 -0
  124. braindecode-1.3.0.dev177069446.dist-info/top_level.txt +1 -0
braindecode/datasets/mne.py
@@ -0,0 +1,170 @@
+ # Authors: Lukas Gemein <l.gemein@gmail.com>
+ #          Robin Schirrmeister <robintibor@gmail.com>
+ #
+ # License: BSD (3-clause)
+
+ from __future__ import annotations
+
+ import mne
+ import numpy as np
+ import pandas as pd
+
+ from .base import BaseConcatDataset, RawDataset, WindowsDataset
+
+
+ def create_from_mne_raw(
+     raws: list[mne.io.BaseRaw],
+     trial_start_offset_samples: int,
+     trial_stop_offset_samples: int,
+     window_size_samples: int,
+     window_stride_samples: int,
+     drop_last_window: bool,
+     descriptions: list[dict | pd.Series] | None = None,
+     mapping: dict[str, int] | None = None,
+     preload: bool = False,
+     drop_bad_windows: bool = True,
+     accepted_bads_ratio: float = 0.0,
+ ) -> BaseConcatDataset:
+     """Create WindowsDatasets from mne.RawArrays.
+
+     Parameters
+     ----------
+     raws : array-like
+         List of mne.RawArrays.
+     trial_start_offset_samples : int
+         Start offset from original trial onsets, in samples.
+     trial_stop_offset_samples : int
+         Stop offset from original trial stops, in samples.
+     window_size_samples : int
+         Window size in samples.
+     window_stride_samples : int
+         Stride between windows, in samples.
+     drop_last_window : bool
+         Whether or not to have a last overlapping window, when windows do
+         not equally divide the continuous signal.
+     descriptions : array-like
+         List of dicts or pandas.Series with additional information about
+         the raws.
+     mapping : dict(str: int)
+         Mapping from event description to target value.
+     preload : bool
+         If True, preload the data of the Epochs objects.
+     drop_bad_windows : bool
+         If True, call `.drop_bad()` on the resulting mne.Epochs object. This
+         step allows identifying e.g., windows that fall outside of the
+         continuous recording. It is suggested to run this step here as
+         otherwise the BaseConcatDataset has to be updated as well.
+     accepted_bads_ratio : float, optional
+         Acceptable proportion of trials with inconsistent length in a raw.
+         If the number of trials whose length is exceeded by the window size
+         is smaller than this, then only the corresponding trials are
+         dropped, but the computation continues. Otherwise, an error is
+         raised. Defaults to 0.0 (raise an error).
+
+     Returns
+     -------
+     windows_datasets : BaseConcatDataset
+         X and y transformed to a dataset format that is compatible with
+         skorch and braindecode.
+     """
+     # Prevent circular import
+     from ..preprocessing.windowers import create_windows_from_events
+
+     if descriptions is not None:
+         if len(descriptions) != len(raws):
+             raise ValueError(
+                 f"length of 'raws' ({len(raws)}) and 'descriptions' "
+                 f"({len(descriptions)}) has to match"
+             )
+         base_datasets = [
+             RawDataset(raw, desc) for raw, desc in zip(raws, descriptions)
+         ]
+     else:
+         base_datasets = [RawDataset(raw) for raw in raws]
+
+     base_datasets = BaseConcatDataset(base_datasets)
+     windows_datasets = create_windows_from_events(
+         base_datasets,
+         trial_start_offset_samples=trial_start_offset_samples,
+         trial_stop_offset_samples=trial_stop_offset_samples,
+         window_size_samples=window_size_samples,
+         window_stride_samples=window_stride_samples,
+         drop_last_window=drop_last_window,
+         mapping=mapping,
+         drop_bad_windows=drop_bad_windows,
+         preload=preload,
+         accepted_bads_ratio=accepted_bads_ratio,
+     )
+     return windows_datasets
+
+
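A minimal usage sketch of create_from_mne_raw (channel names, annotation
labels, and sizes below are made up for illustration; the import path is the
module shown in this diff):

    import mne
    import numpy as np

    from braindecode.datasets.mne import create_from_mne_raw

    # two fake 4-channel, 20 s recordings at 100 Hz, two annotated trials each
    info = mne.create_info(["C3", "C4", "Cz", "Pz"], sfreq=100.0, ch_types="eeg")
    raws = []
    for _ in range(2):
        raw = mne.io.RawArray(np.random.randn(4, 2000), info)
        raw.set_annotations(
            mne.Annotations(onset=[2.0, 8.0], duration=[4.0, 4.0],
                            description=["left", "right"])
        )
        raws.append(raw)

    windows_ds = create_from_mne_raw(
        raws,
        trial_start_offset_samples=0,
        trial_stop_offset_samples=0,
        window_size_samples=100,
        window_stride_samples=100,
        drop_last_window=False,
        mapping={"left": 0, "right": 1},
    )
    X, y, ind = windows_ds[0]  # one window, its target, its position in the trial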
+ def create_from_mne_epochs(
+     list_of_epochs: list[mne.BaseEpochs],
+     window_size_samples: int,
+     window_stride_samples: int,
+     drop_last_window: bool,
+ ) -> BaseConcatDataset:
+     """Create WindowsDatasets from mne.Epochs.
+
+     Parameters
+     ----------
+     list_of_epochs : array-like
+         List of mne.Epochs.
+     window_size_samples : int
+         Window size in samples.
+     window_stride_samples : int
+         Stride between windows, in samples.
+     drop_last_window : bool
+         Whether or not to have a last overlapping window, when windows do
+         not equally divide the continuous signal.
+
+     Returns
+     -------
+     windows_datasets : BaseConcatDataset
+         X and y transformed to a dataset format that is compatible with
+         skorch and braindecode.
+     """
+     # Prevent circular import
+     from ..preprocessing.windowers import _check_windowing_arguments
+
+     _check_windowing_arguments(0, 0, window_size_samples, window_stride_samples)
+
+     list_of_windows_ds = []
+     for epochs in list_of_epochs:
+         event_descriptions = epochs.events[:, 2]
+         original_trial_starts = epochs.events[:, 0]
+         stop = len(epochs.times) - window_size_samples
+
+         # already includes last incomplete window start
+         starts = np.arange(0, stop + 1, window_stride_samples)
+
+         if not drop_last_window and starts[-1] < stop:
+             # if last window does not end at trial stop, make it stop there
+             starts = np.append(starts, stop)
+
+         fake_events = [[start, window_size_samples, -1] for start in starts]
+
+         for trial_i, trial in enumerate(epochs):
+             metadata = pd.DataFrame(
+                 {
+                     "i_window_in_trial": np.arange(len(fake_events)),
+                     "i_start_in_trial": starts + original_trial_starts[trial_i],
+                     "i_stop_in_trial": starts
+                     + original_trial_starts[trial_i]
+                     + window_size_samples,
+                     "target": len(fake_events) * [event_descriptions[trial_i]],
+                 }
+             )
+             # window size - 1, since tmax is inclusive
+             mne_epochs = mne.Epochs(
+                 mne.io.RawArray(trial, epochs.info),
+                 fake_events,
+                 baseline=None,
+                 tmin=0,
+                 tmax=(window_size_samples - 1) / epochs.info["sfreq"],
+                 metadata=metadata,
+             )
+
+             mne_epochs.drop_bad(reject=None, flat=None)
+
+             windows_ds = WindowsDataset(mne_epochs)
+             list_of_windows_ds.append(windows_ds)
+
+     return BaseConcatDataset(list_of_windows_ds)
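A corresponding sketch for create_from_mne_epochs, cutting each 200-sample
epoch into 100-sample windows with 50% overlap (sizes are illustrative):

    import mne
    import numpy as np

    from braindecode.datasets.mne import create_from_mne_epochs

    info = mne.create_info(["C3", "C4"], sfreq=100.0, ch_types="eeg")
    raw = mne.io.RawArray(np.random.randn(2, 3000), info)
    events = np.array([[200, 0, 1], [1200, 0, 2], [2200, 0, 1]])
    # 2 s epochs; tmax is inclusive, hence the one-sample subtraction
    epochs = mne.Epochs(raw, events, tmin=0.0, tmax=1.99, baseline=None,
                        preload=True)

    windows_ds = create_from_mne_epochs(
        [epochs],
        window_size_samples=100,
        window_stride_samples=50,
        drop_last_window=False,
    )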
braindecode/datasets/moabb.py
@@ -0,0 +1,219 @@
+ """Dataset objects for some public datasets."""
+
+ # Authors: Hubert Banville <hubert.jbanville@gmail.com>
+ #          Lukas Gemein <l.gemein@gmail.com>
+ #          Simon Brandt <simonbrandt@protonmail.com>
+ #          David Sabbagh <dav.sabbagh@gmail.com>
+ #          Pierre Guetschel <pierre.guetschel@gmail.com>
+ #
+ # License: BSD (3-clause)
+
+ from __future__ import annotations
+
+ import warnings
+ from typing import Any
+
+ import mne
+ import pandas as pd
+ from mne.utils import deprecated
+
+ from braindecode.util import _update_moabb_docstring
+
+ from .base import BaseConcatDataset, RawDataset
+
+
+ def _find_dataset_in_moabb(dataset_name, dataset_kwargs=None):
+     # soft dependency on moabb
+     from moabb.datasets.utils import dataset_list
+
+     for dataset in dataset_list:
+         if dataset_name == dataset.__name__:
+             # return an instance of the found dataset class
+             if dataset_kwargs is None:
+                 return dataset()
+             else:
+                 return dataset(**dataset_kwargs)
+     raise ValueError(f"{dataset_name} not found in moabb datasets")
+
+
+ def _fetch_and_unpack_moabb_data(dataset, subject_ids=None, dataset_load_kwargs=None):
+     if dataset_load_kwargs is None:
+         data = dataset.get_data(subject_ids)
+     else:
+         data = dataset.get_data(subjects=subject_ids, **dataset_load_kwargs)
+
+     raws, subject_ids, session_ids, run_ids = [], [], [], []
+     for subj_id, subj_data in data.items():
+         for sess_id, sess_data in subj_data.items():
+             for run_id, raw in sess_data.items():
+                 annots = _annotations_from_moabb_stim_channel(raw, dataset)
+                 raw.set_annotations(annots)
+                 raws.append(raw)
+                 subject_ids.append(subj_id)
+                 session_ids.append(sess_id)
+                 run_ids.append(run_id)
+     description = pd.DataFrame(
+         {"subject": subject_ids, "session": session_ids, "run": run_ids}
+     )
+     return raws, description
+
+
+ def _annotations_from_moabb_stim_channel(raw, dataset):
+     # find events from the stim channel
+     stim_channels = mne.utils._get_stim_channel(None, raw.info, raise_error=False)
+     if len(stim_channels) > 0:
+         # returns an empty array if none found
+         events = mne.find_events(raw, shortest_event=0, verbose=False)
+         event_id = dataset.event_id
+     else:
+         events, event_id = mne.events_from_annotations(raw, verbose=False)
+
+     # get annotations from events
+     event_desc = {k: v for v, k in event_id.items()}
+     annots = mne.annotations_from_events(events, raw.info["sfreq"], event_desc)
+
+     # set trial onset and offset given by moabb
+     onset, offset = dataset.interval
+     annots.onset += onset
+     annots.duration += offset - onset
+     return annots
+
+
+ def fetch_data_with_moabb(
+     dataset_name: str,
+     subject_ids: list[int] | int | None = None,
+     dataset_kwargs: dict[str, Any] | None = None,
+     dataset_load_kwargs: dict[str, Any] | None = None,
+ ) -> tuple[list[mne.io.Raw], pd.DataFrame]:
+     # ToDo: update path to where moabb downloads / looks for the data
+     """Fetch data using moabb.
+
+     Parameters
+     ----------
+     dataset_name : str | moabb.datasets.base.BaseDataset
+         The name of a dataset included in moabb, or a moabb dataset instance.
+     subject_ids : list(int) | int
+         (list of) int of subject(s) to be fetched.
+     dataset_kwargs : dict, optional
+         Optional dictionary containing keyword arguments
+         to pass to the moabb dataset when instantiating it.
+     dataset_load_kwargs : dict, optional
+         Optional dictionary containing keyword arguments
+         to pass to the moabb dataset's load_data method.
+         Allows passing moabb options such as cache_config and
+         process_pipeline.
+
+     Returns
+     -------
+     raws : list(mne.io.Raw)
+         The fetched recordings.
+     description : pandas.DataFrame
+         One row per recording, with subject, session and run columns.
+     """
+     if isinstance(dataset_name, str):
+         dataset = _find_dataset_in_moabb(dataset_name, dataset_kwargs)
+     else:
+         from moabb.datasets.base import BaseDataset
+
+         if isinstance(dataset_name, BaseDataset):
+             dataset = dataset_name
+         else:
+             raise ValueError(
+                 "'dataset_name' must be a str or a moabb BaseDataset "
+                 f"instance, got {type(dataset_name)}"
+             )
+
+     subject_id = [subject_ids] if isinstance(subject_ids, int) else subject_ids
+     return _fetch_and_unpack_moabb_data(
+         dataset, subject_id, dataset_load_kwargs=dataset_load_kwargs
+     )
+
+
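A usage sketch (requires the optional moabb dependency and downloads the
recordings on first use; the dataset name matches the classes further below):

    from braindecode.datasets.moabb import fetch_data_with_moabb

    raws, description = fetch_data_with_moabb("BNCI2014_001", subject_ids=[1])
    print(description)  # columns: subject, session, run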
+ class MOABBDataset(BaseConcatDataset):
+     """A class for moabb datasets.
+
+     Parameters
+     ----------
+     dataset_name : str
+         Name of dataset included in moabb to be fetched.
+     subject_ids : list(int) | int | None
+         (list of) int of subject(s) to be fetched. If None, data of all
+         subjects is fetched.
+     dataset_kwargs : dict, optional
+         Optional dictionary containing keyword arguments
+         to pass to the moabb dataset when instantiating it.
+     dataset_load_kwargs : dict, optional
+         Optional dictionary containing keyword arguments
+         to pass to the moabb dataset's load_data method.
+         Allows passing moabb options such as cache_config and
+         process_pipeline.
+     """
+
+     def __init__(
+         self,
+         dataset_name: str,
+         subject_ids: list[int] | int | None = None,
+         dataset_kwargs: dict[str, Any] | None = None,
+         dataset_load_kwargs: dict[str, Any] | None = None,
+     ):
+         # soft dependency on moabb
+         from moabb import __version__ as moabb_version  # type: ignore
+
+         if moabb_version == "1.0.0":
+             warnings.warn(
+                 "moabb version 1.0.0 generates incorrect annotations. "
+                 "Please update to a different version, e.g. 0.5 or 1.1.0."
+             )
+
+         raws, description = fetch_data_with_moabb(
+             dataset_name,
+             subject_ids,
+             dataset_kwargs,
+             dataset_load_kwargs=dataset_load_kwargs,
+         )
+         all_base_ds = [
+             RawDataset(raw, row) for raw, (_, row) in zip(raws, description.iterrows())
+         ]
+         super().__init__(all_base_ds)
+
+
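The equivalent high-level entry point, sketched for one subject of
BNCI2014_001 (data are downloaded via moabb on first use):

    from braindecode.datasets.moabb import MOABBDataset

    ds = MOABBDataset(dataset_name="BNCI2014_001", subject_ids=[1])
    print(ds.description)  # one row per (subject, session, run)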
+ class BNCI2014_001(MOABBDataset):
+     doc = """See moabb.datasets.bnci.BNCI2014_001
+
+     Parameters
+     ----------
+     subject_ids : list(int) | int | None
+         (list of) int of subject(s) to be fetched. If None, data of all
+         subjects is fetched.
+     """
+     try:
+         from moabb.datasets import BNCI2014_001
+
+         __doc__ = _update_moabb_docstring(BNCI2014_001, doc)
+     except ModuleNotFoundError:
+         # keep moabb a soft dependency; otherwise datasets.__init__.py
+         # would crash on import
+         pass
+
+     def __init__(self, subject_ids):
+         super().__init__("BNCI2014_001", subject_ids=subject_ids)
+
+
+ class HGD(MOABBDataset):
+     doc = """See moabb.datasets.schirrmeister2017.Schirrmeister2017
+
+     Parameters
+     ----------
+     subject_ids : list(int) | int | None
+         (list of) int of subject(s) to be fetched. If None, data of all
+         subjects is fetched.
+     """
+     try:
+         from moabb.datasets import Schirrmeister2017
+
+         __doc__ = _update_moabb_docstring(Schirrmeister2017, doc)
+     except ModuleNotFoundError:
+         # keep moabb a soft dependency; otherwise datasets.__init__.py
+         # would crash on import
+         pass
+
+     def __init__(self, subject_ids):
+         super().__init__("Schirrmeister2017", subject_ids=subject_ids)
+
+
+ @deprecated(
+     "`BNCI2014001` was renamed to `BNCI2014_001` in v1.13; this alias will be removed in v1.14."
+ )
+ class BNCI2014001(BNCI2014_001):
+     """Deprecated alias for BNCI2014_001."""
+
+     pass
braindecode/datasets/nmt.py
@@ -0,0 +1,313 @@
+ """
+ Dataset classes for the NMT EEG Corpus dataset.
+
+ The NMT Scalp EEG Dataset is an open-source annotated dataset of healthy and
+ pathological EEG recordings for predictive modeling. The dataset contains
+ 2,417 recordings from unique participants spanning almost 625 h.
+
+ Notes:
+ - The signal unit may not be uV; further examination is required.
+ - The spectrum suggests the signal may have been band-pass filtered to
+   about 2-33 Hz, which still needs to be verified.
+ """
+
+ # Authors: Mohammad Bayazi <mj.darvishi92@gmail.com>
+ #          Bruno Aristimunha <b.aristimunha@gmail.com>
+ #
+ # License: BSD (3-clause)
+
+ from __future__ import annotations
+
+ import glob
+ import os
+ import warnings
+ from pathlib import Path
+ from unittest import mock
+
+ import mne
+ import numpy as np
+ import pandas as pd
+ from joblib import Parallel, delayed
+ from mne.datasets import fetch_dataset
+
+ from braindecode.datasets.base import BaseConcatDataset, RawDataset
+ from braindecode.datasets.utils import _correct_dataset_path
+
+ NMT_URL = "https://zenodo.org/record/10909103/files/NMT.zip"
+ NMT_archive_name = "NMT.zip"
+ NMT_folder_name = "MNE-NMT-eeg-dataset"
+ NMT_dataset_name = "NMT-EEG-Corpus"
+
+ NMT_dataset_params = {
+     "dataset_name": NMT_dataset_name,
+     "url": NMT_URL,
+     "archive_name": NMT_archive_name,
+     "folder_name": NMT_folder_name,
+     "hash": "77b3ce12bcaf6c6cce4e6690ea89cb22bed55af10c525077b430f6e1d2e3c6bf",
+     "config_key": NMT_dataset_name,
+ }
+
+
+ class NMT(BaseConcatDataset):
+     """The NMT Scalp EEG Dataset.
+
+     An Open-Source Annotated Dataset of Healthy and Pathological EEG
+     Recordings for Predictive Modeling.
+
+     This dataset contains 2,417 recordings from unique participants spanning
+     almost 625 h.
+
+     The dataset can be used for three tasks: brain-age prediction, gender
+     prediction, and abnormality detection.
+
+     The dataset is described in [Khan2022]_.
+
+     .. versionadded:: 0.9
+
+     Parameters
+     ----------
+     path : str
+         Parent directory of the dataset.
+     recording_ids : list(int) | int
+         A (list of) int of recording id(s) to be read. Order matters and
+         will overwrite the default chronological order: e.g. if
+         recording_ids=[1, 0], the first recording returned by this class
+         will be chronologically later than the second. Provide
+         recording_ids in ascending order to preserve chronological order.
+     target_name : str
+         Can be "pathological", "gender", or "age".
+     preload : bool
+         If True, preload the data of the Raw objects.
+     n_jobs : int
+         Number of jobs used to read the recordings in parallel.
+
+     References
+     ----------
+     .. [Khan2022] Khan, H.A., Ul Ain, R., Kamboh, A.M., Butt, H.T.,
+        Shafait, S., Alamgir, W., Stricker, D. and Shafait, F., 2022. The NMT
+        scalp EEG dataset: an open-source annotated dataset of healthy and
+        pathological EEG recordings for predictive modeling. Frontiers in
+        Neuroscience, 15, p.755817.
+     """
+
+     def __init__(
+         self,
+         path=None,
+         target_name="pathological",
+         recording_ids=None,
+         preload=False,
+         n_jobs=1,
+     ):
+         # Convert empty string to None for consistency
+         if path == "":
+             path = None
+
+         # Download dataset if not present
+         if path is None:
+             path = fetch_dataset(
+                 dataset_params=NMT_dataset_params,
+                 path=None,
+                 processor="unzip",
+                 force_update=False,
+             )
+             # First time we fetch the dataset, we need to move the files to
+             # the correct directory.
+             path = _correct_dataset_path(
+                 path, NMT_archive_name, "nmt_scalp_eeg_dataset"
+             )
+         else:
+             # Validate that the provided path is a valid NMT dataset
+             if not Path(f"{path}/Labels.csv").exists():
+                 raise ValueError(
+                     f"The provided path {path} does not contain a valid "
+                     "NMT dataset (missing Labels.csv). Please ensure the "
+                     "path points directly to the NMT dataset directory."
+                 )
+             path = _correct_dataset_path(
+                 path, NMT_archive_name, "nmt_scalp_eeg_dataset"
+             )
+
129
+ file_paths = glob.glob(
130
+ os.path.join(path, "**" + os.sep + "*.edf"), recursive=True
131
+ )
132
+
133
+ # sort by subject id
134
+ file_paths = [
135
+ file_path
136
+ for file_path in file_paths
137
+ if os.path.splitext(file_path)[1] == ".edf"
138
+ ]
139
+
140
+ # sort by subject id
141
+ file_paths = sorted(
142
+ file_paths, key=lambda p: int(os.path.splitext(p)[0].split(os.sep)[-1])
143
+ )
144
+ if recording_ids is not None:
145
+ file_paths = [file_paths[rec_id] for rec_id in recording_ids]
146
+
147
+ # read labels and rearrange them to match TUH Abnormal EEG Corpus
148
+ description = pd.read_csv(
149
+ os.path.join(path, "Labels.csv"), index_col="recordname"
150
+ )
151
+ if recording_ids is not None:
152
+ # Match metadata by record name instead of position to fix alignment bug
153
+ # when CSV order differs from sorted file order
154
+ selected_recordnames = [os.path.basename(fp) for fp in file_paths]
155
+ description = description.loc[selected_recordnames]
156
+ description.replace(
157
+ {
158
+ "not specified": "X",
159
+ "female": "F",
160
+ "male": "M",
161
+ "abnormal": True,
162
+ "normal": False,
163
+ },
164
+ inplace=True,
165
+ )
166
+ description.rename(columns={"label": "pathological"}, inplace=True)
167
+ description.reset_index(drop=True, inplace=True)
168
+ description["path"] = file_paths
169
+ description = description[["path", "pathological", "age", "gender"]]
170
+
171
+ if n_jobs == 1:
172
+ base_datasets = [
173
+ self._create_dataset(d, target_name, preload)
174
+ for recording_id, d in description.iterrows()
175
+ ]
176
+ else:
177
+ base_datasets = Parallel(n_jobs)(
178
+ delayed(self._create_dataset)(d, target_name, preload)
179
+ for recording_id, d in description.iterrows()
180
+ )
181
+
182
+ super().__init__(base_datasets)
183
+
+     @staticmethod
+     def _create_dataset(d, target_name, preload):
+         raw = mne.io.read_raw_edf(d.path, preload=preload)
+         d["n_samples"] = raw.n_times
+         d["sfreq"] = raw.info["sfreq"]
+         d["train"] = "train" in d.path.split(os.sep)
+         base_dataset = RawDataset(raw, d, target_name)
+         return base_dataset
+
+
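A usage sketch (calling NMT without a path triggers a large download from
Zenodo; the recording_ids and target_name values are illustrative):

    from braindecode.datasets.nmt import NMT

    ds = NMT(path=None, target_name="pathological", recording_ids=[0, 1])
    print(ds.description[["path", "pathological", "age", "gender"]])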
+ def _get_header(*args):
+     all_paths = {**_NMT_PATHS}
+     return all_paths[args[0]]
+
+
+ def _fake_pd_read_csv(*args, **kwargs):
+     # Create a list of lists to hold the data.
+     # Updated to match the file IDs from _NMT_PATHS (0000036-0000042)
+     # to align with the mocked glob.glob return value.
+     data = [
+         ["0000036.edf", "normal", 35, "male", "train"],
+         ["0000037.edf", "abnormal", 28, "female", "test"],
+         ["0000038.edf", "normal", 62, "male", "train"],
+         ["0000039.edf", "abnormal", 41, "female", "test"],
+         ["0000040.edf", "normal", 19, "male", "train"],
+         ["0000041.edf", "abnormal", 55, "female", "test"],
+         ["0000042.edf", "normal", 71, "male", "train"],
+     ]
+
+     # Create the DataFrame, specifying column names
+     df = pd.DataFrame(data, columns=["recordname", "label", "age", "gender", "loc"])
+
+     # Set recordname as index to match the real pd.read_csv behavior
+     # with index_col="recordname"
+     df.set_index("recordname", inplace=True)
+
+     return df
+
+
+ def _fake_raw(*args, **kwargs):
+     sfreq = 10
+     ch_names = [
+         "EEG A1-REF",
+         "EEG A2-REF",
+         "EEG FP1-REF",
+         "EEG FP2-REF",
+         "EEG F3-REF",
+         "EEG F4-REF",
+         "EEG C3-REF",
+         "EEG C4-REF",
+         "EEG P3-REF",
+         "EEG P4-REF",
+         "EEG O1-REF",
+         "EEG O2-REF",
+         "EEG F7-REF",
+         "EEG F8-REF",
+         "EEG T3-REF",
+         "EEG T4-REF",
+         "EEG T5-REF",
+         "EEG T6-REF",
+         "EEG FZ-REF",
+         "EEG CZ-REF",
+         "EEG PZ-REF",
+     ]
+     duration_min = 6
+     data = np.random.randn(len(ch_names), duration_min * sfreq * 60)
+     info = mne.create_info(ch_names=ch_names, sfreq=sfreq, ch_types="eeg")
+     raw = mne.io.RawArray(data=data, info=info)
+     return raw
+
+
+ _NMT_PATHS = {
+     # these are actual file paths and edf headers from the NMT EEG Corpus
+     "nmt_scalp_eeg_dataset/abnormal/train/0000036.edf": b"0 0000036 M 13-May-1951 0000036 Age:32 ",  # noqa E501
+     "nmt_scalp_eeg_dataset/abnormal/eval/0000037.edf": b"0 0000037 M 13-May-1951 0000037 Age:32 ",  # noqa E501
+     "nmt_scalp_eeg_dataset/abnormal/eval/0000038.edf": b"0 0000038 M 13-May-1951 0000038 Age:32 ",  # noqa E501
+     "nmt_scalp_eeg_dataset/normal/train/0000039.edf": b"0 0000039 M 13-May-1951 0000039 Age:32 ",  # noqa E501
+     "nmt_scalp_eeg_dataset/normal/eval/0000040.edf": b"0 0000040 M 13-May-1951 0000040 Age:32 ",  # noqa E501
+     "nmt_scalp_eeg_dataset/normal/eval/0000041.edf": b"0 0000041 M 13-May-1951 0000041 Age:32 ",  # noqa E501
+     "nmt_scalp_eeg_dataset/abnormal/train/0000042.edf": b"0 0000042 M 13-May-1951 0000042 Age:32 ",  # noqa E501
+     "Labels.csv": b"0 recordname,label,age,gender,loc 1 0000001.edf,normal,22,not specified,train ",  # noqa E501
+ }
+
+
+ class _NMTMock(NMT):
+     """Mocked class for testing and examples."""
+
+     @mock.patch("pathlib.Path.exists", return_value=True)
+     @mock.patch("braindecode.datasets.nmt._correct_dataset_path")
+     @mock.patch("mne.datasets.fetch_dataset")
+     @mock.patch("pandas.read_csv", new=_fake_pd_read_csv)
+     @mock.patch("mne.io.read_raw_edf", new=_fake_raw)
+     @mock.patch("glob.glob", return_value=_NMT_PATHS.keys())
+     def __init__(
+         self,
+         mock_glob,
+         mock_fetch,
+         mock_correct_path,
+         mock_path_exists,
+         path,
+         recording_ids=None,
+         target_name="pathological",
+         preload=False,
+         n_jobs=1,
+     ):
+         # Prevent download by providing a dummy path if empty/None
+         if not path:
+             path = "mocked_nmt_path"
+
+         # Mock fetch_dataset to return a valid path without downloading
+         mock_fetch.return_value = path
+         # Mock _correct_dataset_path to return the path as-is
+         mock_correct_path.side_effect = lambda p, *args, **kwargs: p
+
+         with warnings.catch_warnings():
+             warnings.filterwarnings("ignore", message="Cannot save date file")
+             super().__init__(
+                 path=path,
+                 recording_ids=recording_ids,
+                 target_name=target_name,
+                 preload=preload,
+                 n_jobs=n_jobs,
+             )
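A quick sketch of the mock in a test (no download or disk access; the file
listing, labels, and raw signals are all faked by the patches above):

    from braindecode.datasets.nmt import _NMTMock

    ds = _NMTMock(path="mocked_nmt_path")
    assert len(ds.datasets) == 7  # the seven mocked .edf recordings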