braindecode 0.8.1__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff shows the changes between publicly released package versions as they appear in their respective public registries, and is provided for informational purposes only.

Potentially problematic release: this version of braindecode was flagged by the registry scanner.

Files changed (108)
  1. braindecode/__init__.py +1 -2
  2. braindecode/augmentation/__init__.py +39 -19
  3. braindecode/augmentation/base.py +25 -28
  4. braindecode/augmentation/functional.py +237 -100
  5. braindecode/augmentation/transforms.py +325 -158
  6. braindecode/classifier.py +26 -24
  7. braindecode/datasets/__init__.py +28 -10
  8. braindecode/datasets/base.py +220 -134
  9. braindecode/datasets/bbci.py +43 -52
  10. braindecode/datasets/bcicomp.py +47 -32
  11. braindecode/datasets/bids.py +245 -0
  12. braindecode/datasets/mne.py +45 -24
  13. braindecode/datasets/moabb.py +87 -27
  14. braindecode/datasets/nmt.py +311 -0
  15. braindecode/datasets/sleep_physio_challe_18.py +412 -0
  16. braindecode/datasets/sleep_physionet.py +43 -26
  17. braindecode/datasets/tuh.py +324 -140
  18. braindecode/datasets/xy.py +27 -12
  19. braindecode/datautil/__init__.py +37 -18
  20. braindecode/datautil/serialization.py +110 -72
  21. braindecode/eegneuralnet.py +63 -47
  22. braindecode/functional/__init__.py +22 -0
  23. braindecode/functional/functions.py +250 -0
  24. braindecode/functional/initialization.py +47 -0
  25. braindecode/models/__init__.py +84 -14
  26. braindecode/models/atcnet.py +193 -164
  27. braindecode/models/attentionbasenet.py +599 -0
  28. braindecode/models/base.py +86 -102
  29. braindecode/models/biot.py +504 -0
  30. braindecode/models/contrawr.py +317 -0
  31. braindecode/models/ctnet.py +536 -0
  32. braindecode/models/deep4.py +116 -77
  33. braindecode/models/deepsleepnet.py +149 -119
  34. braindecode/models/eegconformer.py +112 -173
  35. braindecode/models/eeginception_erp.py +109 -118
  36. braindecode/models/eeginception_mi.py +161 -97
  37. braindecode/models/eegitnet.py +215 -152
  38. braindecode/models/eegminer.py +254 -0
  39. braindecode/models/eegnet.py +228 -161
  40. braindecode/models/eegnex.py +247 -0
  41. braindecode/models/eegresnet.py +234 -152
  42. braindecode/models/eegsimpleconv.py +199 -0
  43. braindecode/models/eegtcnet.py +335 -0
  44. braindecode/models/fbcnet.py +221 -0
  45. braindecode/models/fblightconvnet.py +313 -0
  46. braindecode/models/fbmsnet.py +324 -0
  47. braindecode/models/hybrid.py +52 -71
  48. braindecode/models/ifnet.py +441 -0
  49. braindecode/models/labram.py +1186 -0
  50. braindecode/models/msvtnet.py +375 -0
  51. braindecode/models/sccnet.py +207 -0
  52. braindecode/models/shallow_fbcsp.py +50 -56
  53. braindecode/models/signal_jepa.py +1011 -0
  54. braindecode/models/sinc_shallow.py +337 -0
  55. braindecode/models/sleep_stager_blanco_2020.py +55 -46
  56. braindecode/models/sleep_stager_chambon_2018.py +54 -53
  57. braindecode/models/sleep_stager_eldele_2021.py +247 -141
  58. braindecode/models/sparcnet.py +424 -0
  59. braindecode/models/summary.csv +41 -0
  60. braindecode/models/syncnet.py +232 -0
  61. braindecode/models/tcn.py +158 -88
  62. braindecode/models/tidnet.py +280 -167
  63. braindecode/models/tsinception.py +283 -0
  64. braindecode/models/usleep.py +190 -177
  65. braindecode/models/util.py +109 -145
  66. braindecode/modules/__init__.py +84 -0
  67. braindecode/modules/activation.py +60 -0
  68. braindecode/modules/attention.py +757 -0
  69. braindecode/modules/blocks.py +108 -0
  70. braindecode/modules/convolution.py +274 -0
  71. braindecode/modules/filter.py +628 -0
  72. braindecode/modules/layers.py +131 -0
  73. braindecode/modules/linear.py +49 -0
  74. braindecode/modules/parametrization.py +38 -0
  75. braindecode/modules/stats.py +77 -0
  76. braindecode/modules/util.py +76 -0
  77. braindecode/modules/wrapper.py +73 -0
  78. braindecode/preprocessing/__init__.py +36 -11
  79. braindecode/preprocessing/mne_preprocess.py +13 -7
  80. braindecode/preprocessing/preprocess.py +139 -75
  81. braindecode/preprocessing/windowers.py +576 -187
  82. braindecode/regressor.py +23 -12
  83. braindecode/samplers/__init__.py +16 -8
  84. braindecode/samplers/base.py +146 -32
  85. braindecode/samplers/ssl.py +162 -17
  86. braindecode/training/__init__.py +18 -10
  87. braindecode/training/callbacks.py +2 -4
  88. braindecode/training/losses.py +3 -8
  89. braindecode/training/scoring.py +76 -68
  90. braindecode/util.py +55 -59
  91. braindecode/version.py +1 -1
  92. braindecode/visualization/__init__.py +2 -3
  93. braindecode/visualization/confusion_matrices.py +117 -73
  94. braindecode/visualization/gradients.py +14 -10
  95. {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/METADATA +42 -58
  96. braindecode-1.1.0.dist-info/RECORD +101 -0
  97. {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/WHEEL +1 -1
  98. {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info/licenses}/LICENSE.txt +1 -1
  99. braindecode-1.1.0.dist-info/licenses/NOTICE.txt +20 -0
  100. braindecode/datautil/mne.py +0 -9
  101. braindecode/datautil/preprocess.py +0 -12
  102. braindecode/datautil/windowers.py +0 -6
  103. braindecode/datautil/xy.py +0 -9
  104. braindecode/models/eeginception.py +0 -317
  105. braindecode/models/functions.py +0 -47
  106. braindecode/models/modules.py +0 -358
  107. braindecode-0.8.1.dist-info/RECORD +0 -68
  108. {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/top_level.txt +0 -0
braindecode/datasets/mne.py

@@ -3,18 +3,28 @@
 #
 # License: BSD (3-clause)
 
+from __future__ import annotations
+
+import mne
 import numpy as np
 import pandas as pd
-import mne
 
-from .base import BaseDataset, BaseConcatDataset, WindowsDataset
+from .base import BaseConcatDataset, BaseDataset, WindowsDataset
 
 
 def create_from_mne_raw(
-        raws, trial_start_offset_samples, trial_stop_offset_samples,
-        window_size_samples, window_stride_samples, drop_last_window,
-        descriptions=None, mapping=None, preload=False, drop_bad_windows=True,
-        accepted_bads_ratio=0.0):
+    raws: list[mne.io.BaseRaw],
+    trial_start_offset_samples: int,
+    trial_stop_offset_samples: int,
+    window_size_samples: int,
+    window_stride_samples: int,
+    drop_last_window: bool,
+    descriptions: list[dict | pd.Series] | None = None,
+    mapping: dict[str, int] | None = None,
+    preload: bool = False,
+    drop_bad_windows: bool = True,
+    accepted_bads_ratio: float = 0.0,
+) -> BaseConcatDataset:
     """Create WindowsDatasets from mne.RawArrays
 
     Parameters
@@ -58,13 +68,16 @@ def create_from_mne_raw(
     """
     # Prevent circular import
     from ..preprocessing.windowers import create_windows_from_events
+
     if descriptions is not None:
         if len(descriptions) != len(raws):
             raise ValueError(
                 f"length of 'raws' ({len(raws)}) and 'description' "
-                f"({len(descriptions)}) has to match")
-        base_datasets = [BaseDataset(raw, desc) for raw, desc in
-                         zip(raws, descriptions)]
+                f"({len(descriptions)}) has to match"
+            )
+        base_datasets = [
+            BaseDataset(raw, desc) for raw, desc in zip(raws, descriptions)
+        ]
     else:
         base_datasets = [BaseDataset(raw) for raw in raws]
 
@@ -84,8 +97,12 @@ def create_from_mne_raw(
     return windows_datasets
 
 
-def create_from_mne_epochs(list_of_epochs, window_size_samples,
-                           window_stride_samples, drop_last_window):
+def create_from_mne_epochs(
+    list_of_epochs: list[mne.BaseEpochs],
+    window_size_samples: int,
+    window_stride_samples: int,
+    drop_last_window: bool,
+) -> BaseConcatDataset:
     """Create WindowsDatasets from mne.Epochs
 
     Parameters
@@ -108,8 +125,8 @@ def create_from_mne_epochs(list_of_epochs, window_size_samples,
     """
     # Prevent circular import
     from ..preprocessing.windowers import _check_windowing_arguments
-    _check_windowing_arguments(0, 0, window_size_samples,
-                               window_stride_samples)
+
+    _check_windowing_arguments(0, 0, window_size_samples, window_stride_samples)
 
     list_of_windows_ds = []
     for epochs in list_of_epochs:
@@ -124,24 +141,28 @@ def create_from_mne_epochs(list_of_epochs, window_size_samples,
         # if last window does not end at trial stop, make it stop there
         starts = np.append(starts, stop)
 
-        fake_events = [[start, window_size_samples, -1] for start in
-                       starts]
+        fake_events = [[start, window_size_samples, -1] for start in starts]
 
         for trial_i, trial in enumerate(epochs):
-            metadata = pd.DataFrame({
-                'i_window_in_trial': np.arange(len(fake_events)),
-                'i_start_in_trial': starts + original_trial_starts[trial_i],
-                'i_stop_in_trial': starts + original_trial_starts[
-                    trial_i] + window_size_samples,
-                'target': len(fake_events) * [event_descriptions[trial_i]]
-            })
+            metadata = pd.DataFrame(
+                {
+                    "i_window_in_trial": np.arange(len(fake_events)),
+                    "i_start_in_trial": starts + original_trial_starts[trial_i],
+                    "i_stop_in_trial": starts
+                    + original_trial_starts[trial_i]
+                    + window_size_samples,
+                    "target": len(fake_events) * [event_descriptions[trial_i]],
+                }
+            )
             # window size - 1, since tmax is inclusive
             mne_epochs = mne.Epochs(
-                mne.io.RawArray(trial, epochs.info), fake_events,
+                mne.io.RawArray(trial, epochs.info),
+                fake_events,
                 baseline=None,
                 tmin=0,
                 tmax=(window_size_samples - 1) / epochs.info["sfreq"],
-                metadata=metadata)
+                metadata=metadata,
+            )
 
             mne_epochs.drop_bad(reject=None, flat=None)
 
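The retyped create_from_mne_raw signature above is easiest to read next to a usage example. A minimal sketch, assuming a synthetic annotated recording (the channel names, annotation labels, and window sizes are illustrative, not taken from the release):

import mne
import numpy as np

from braindecode.datasets import create_from_mne_raw

# Build a tiny synthetic 2-channel recording with two annotated trials.
sfreq = 100.0
rng = np.random.default_rng(0)
data = rng.standard_normal((2, 1000))  # 10 s of fake EEG
info = mne.create_info(ch_names=["C3", "C4"], sfreq=sfreq, ch_types="eeg")
raw = mne.io.RawArray(data, info)
raw.set_annotations(
    mne.Annotations(onset=[1.0, 5.0], duration=[2.0, 2.0],
                    description=["left", "right"])
)

# Cut each annotated trial into non-overlapping 1 s windows.
windows_ds = create_from_mne_raw(
    raws=[raw],
    trial_start_offset_samples=0,
    trial_stop_offset_samples=0,
    window_size_samples=100,
    window_stride_samples=100,
    drop_last_window=False,
    mapping={"left": 0, "right": 1},
)
print(len(windows_ds))  # total number of extracted windows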
braindecode/datasets/moabb.py

@@ -1,5 +1,4 @@
-"""Dataset objects for some public datasets.
-"""
+"""Dataset objects for some public datasets."""
 
 # Authors: Hubert Banville <hubert.jbanville@gmail.com>
 #          Lukas Gemein <l.gemein@gmail.com>
@@ -9,16 +8,23 @@
 #
 # License: BSD (3-clause)
 
-import pandas as pd
+from __future__ import annotations
+
+import warnings
+from typing import Any
+
 import mne
+import pandas as pd
 
-from .base import BaseDataset, BaseConcatDataset
 from braindecode.util import _update_moabb_docstring
 
+from .base import BaseConcatDataset, BaseDataset
+
 
 def _find_dataset_in_moabb(dataset_name, dataset_kwargs=None):
     # soft dependency on moabb
     from moabb.datasets.utils import dataset_list
+
     for dataset in dataset_list:
         if dataset_name == dataset.__name__:
             # return an instance of the found dataset class
@@ -29,35 +35,41 @@ def _find_dataset_in_moabb(dataset_name, dataset_kwargs=None):
     raise ValueError(f"{dataset_name} not found in moabb datasets")
 
 
-def _fetch_and_unpack_moabb_data(dataset, subject_ids):
-    data = dataset.get_data(subject_ids)
+def _fetch_and_unpack_moabb_data(dataset, subject_ids=None, dataset_load_kwargs=None):
+    if dataset_load_kwargs is None:
+        data = dataset.get_data(subject_ids)
+    else:
+        data = dataset.get_data(subjects=subject_ids, **dataset_load_kwargs)
+
     raws, subject_ids, session_ids, run_ids = [], [], [], []
     for subj_id, subj_data in data.items():
         for sess_id, sess_data in subj_data.items():
             for run_id, raw in sess_data.items():
-                # set annotation if empty
-                if len(raw.annotations) == 0:
-                    annots = _annotations_from_moabb_stim_channel(raw, dataset)
-                    raw.set_annotations(annots)
+                annots = _annotations_from_moabb_stim_channel(raw, dataset)
+                raw.set_annotations(annots)
                 raws.append(raw)
                 subject_ids.append(subj_id)
                 session_ids.append(sess_id)
                 run_ids.append(run_id)
-    description = pd.DataFrame({
-        'subject': subject_ids,
-        'session': session_ids,
-        'run': run_ids
-    })
+    description = pd.DataFrame(
+        {"subject": subject_ids, "session": session_ids, "run": run_ids}
+    )
     return raws, description
 
 
 def _annotations_from_moabb_stim_channel(raw, dataset):
-    # find events from stim channel
-    events = mne.find_events(raw)
+    # find events from the stim channel
+    stim_channels = mne.utils._get_stim_channel(None, raw.info, raise_error=False)
+    if len(stim_channels) > 0:
+        # returns an empty array if none found
+        events = mne.find_events(raw, shortest_event=0, verbose=False)
+        event_id = dataset.event_id
+    else:
+        events, event_id = mne.events_from_annotations(raw, verbose=False)
 
     # get annotations from events
-    event_desc = {k: v for v, k in dataset.event_id.items()}
-    annots = mne.annotations_from_events(events, raw.info['sfreq'], event_desc)
+    event_desc = {k: v for v, k in event_id.items()}
+    annots = mne.annotations_from_events(events, raw.info["sfreq"], event_desc)
 
     # set trial on and offset given by moabb
     onset, offset = dataset.interval
@@ -66,28 +78,47 @@ def _annotations_from_moabb_stim_channel(raw, dataset):
     return annots
 
 
-def fetch_data_with_moabb(dataset_name, subject_ids, dataset_kwargs=None):
+def fetch_data_with_moabb(
+    dataset_name: str,
+    subject_ids: list[int] | int | None = None,
+    dataset_kwargs: dict[str, Any] | None = None,
+    dataset_load_kwargs: dict[str, Any] | None = None,
+) -> tuple[list[mne.io.Raw], pd.DataFrame]:
     # ToDo: update path to where moabb downloads / looks for the data
     """Fetch data using moabb.
 
     Parameters
     ----------
-    dataset_name: str
+    dataset_name: str | moabb.datasets.base.BaseDataset
         the name of a dataset included in moabb
     subject_ids: list(int) | int
         (list of) int of subject(s) to be fetched
     dataset_kwargs: dict, optional
         optional dictionary containing keyword arguments
         to pass to the moabb dataset when instantiating it.
+    data_load_kwargs: dict, optional
+        optional dictionary containing keyword arguments
+        to pass to the moabb dataset's load_data method.
+        Allows using the moabb cache_config=None and
+        process_pipeline=None.
 
     Returns
     -------
     raws: mne.Raw
    info: pandas.DataFrame
     """
-    dataset = _find_dataset_in_moabb(dataset_name, dataset_kwargs)
+    if isinstance(dataset_name, str):
+        dataset = _find_dataset_in_moabb(dataset_name, dataset_kwargs)
+    else:
+        from moabb.datasets.base import BaseDataset
+
+        if isinstance(dataset_name, BaseDataset):
+            dataset = dataset_name
+
     subject_id = [subject_ids] if isinstance(subject_ids, int) else subject_ids
-    return _fetch_and_unpack_moabb_data(dataset, subject_id)
+    return _fetch_and_unpack_moabb_data(
+        dataset, subject_id, dataset_load_kwargs=dataset_load_kwargs
+    )
 
 
 class MOABBDataset(BaseConcatDataset):
@@ -103,11 +134,38 @@ class MOABBDataset(BaseConcatDataset):
     dataset_kwargs: dict, optional
         optional dictionary containing keyword arguments
         to pass to the moabb dataset when instantiating it.
+    dataset_load_kwargs: dict, optional
+        optional dictionary containing keyword arguments
+        to pass to the moabb dataset's load_data method.
+        Allows using the moabb cache_config=None and
+        process_pipeline=None.
     """
-    def __init__(self, dataset_name, subject_ids, dataset_kwargs=None):
-        raws, description = fetch_data_with_moabb(dataset_name, subject_ids, dataset_kwargs)
-        all_base_ds = [BaseDataset(raw, row)
-                       for raw, (_, row) in zip(raws, description.iterrows())]
+
+    def __init__(
+        self,
+        dataset_name: str,
+        subject_ids: list[int] | int | None = None,
+        dataset_kwargs: dict[str, Any] | None = None,
+        dataset_load_kwargs: dict[str, Any] | None = None,
+    ):
+        # soft dependency on moabb
+        from moabb import __version__ as moabb_version  # type: ignore
+
+        if moabb_version == "1.0.0":
+            warnings.warn(
+                "moabb version 1.0.0 generates incorrect annotations. "
+                "Please update to another version, version 0.5 or 1.1.0 "
+            )
+
+        raws, description = fetch_data_with_moabb(
+            dataset_name,
+            subject_ids,
+            dataset_kwargs,
+            dataset_load_kwargs=dataset_load_kwargs,
+        )
+        all_base_ds = [
+            BaseDataset(raw, row) for raw, (_, row) in zip(raws, description.iterrows())
+        ]
         super().__init__(all_base_ds)
 
 
@@ -122,6 +180,7 @@ class BNCI2014001(MOABBDataset):
     """
     try:
         from moabb.datasets import BNCI2014001
+
         __doc__ = _update_moabb_docstring(BNCI2014001, doc)
     except ModuleNotFoundError:
         pass  # keep moabb soft dependency, otherwise crash on loading of datasets.__init__.py
@@ -141,6 +200,7 @@ class HGD(MOABBDataset):
     """
     try:
         from moabb.datasets import Schirrmeister2017
+
         __doc__ = _update_moabb_docstring(Schirrmeister2017, doc)
     except ModuleNotFoundError:
         pass  # keep moabb soft dependency, otherwise crash on loading of datasets.__init__.py
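
The user-visible change in this file is the new dataset_load_kwargs pass-through on MOABBDataset and fetch_data_with_moabb. A minimal sketch of how it is meant to be called, assuming a working moabb install (the kwargs mirror the docstring above; the dataset name string must match the class name in your installed moabb version):

from braindecode.datasets import MOABBDataset

# Fetch one subject of the BNCI 2014-001 dataset through moabb; the
# dataset_load_kwargs dict is forwarded to the moabb dataset's get_data call.
ds = MOABBDataset(
    dataset_name="BNCI2014001",
    subject_ids=[1],
    dataset_load_kwargs=dict(cache_config=None, process_pipeline=None),
)
print(ds.description)  # one row per (subject, session, run)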
braindecode/datasets/nmt.py (new file)

@@ -0,0 +1,311 @@
+"""
+Dataset classes for the NMT EEG Corpus dataset.
+
+The NMT Scalp EEG Dataset is an open-source annotated dataset of healthy and
+pathological EEG recordings for predictive modeling. This dataset contains
+2,417 recordings from unique participants spanning almost 625 h.
+
+Note:
+    - The signal unit may not be uV and further examination is required.
+    - The spectrum shows that the signal may have been band-pass filtered
+      from about 2 - 33 Hz, which needs to be further determined.
+
+"""
+
+# Authors: Mohammad Bayazi <mj.darvishi92@gmail.com>
+#          Bruno Aristimunha <b.aristimunha@gmail.com>
+#
+# License: BSD (3-clause)
+
+from __future__ import annotations
+
+import glob
+import os
+import warnings
+from pathlib import Path
+from unittest import mock
+
+import mne
+import numpy as np
+import pandas as pd
+from joblib import Parallel, delayed
+from mne.datasets import fetch_dataset
+
+from braindecode.datasets.base import BaseConcatDataset, BaseDataset
+
+NMT_URL = "https://zenodo.org/record/10909103/files/NMT.zip"
+NMT_archive_name = "NMT.zip"
+NMT_folder_name = "MNE-NMT-eeg-dataset"
+NMT_dataset_name = "NMT-EEG-Corpus"
+
+NMT_dataset_params = {
+    "dataset_name": NMT_dataset_name,
+    "url": NMT_URL,
+    "archive_name": NMT_archive_name,
+    "folder_name": NMT_folder_name,
+    "hash": "77b3ce12bcaf6c6cce4e6690ea89cb22bed55af10c525077b430f6e1d2e3c6bf",
+    "config_key": NMT_dataset_name,
+}
+
+
+class NMT(BaseConcatDataset):
+    """The NMT Scalp EEG Dataset.
+
+    An Open-Source Annotated Dataset of Healthy and Pathological EEG
+    Recordings for Predictive Modeling.
+
+    This dataset contains 2,417 recordings from unique participants spanning
+    almost 625 h.
+
+    Here, the dataset can be used for three tasks: brain-age, gender
+    prediction, and abnormality detection.
+
+    The dataset is described in [Khan2022]_.
+
+    .. versionadded:: 0.9
+
+    Parameters
+    ----------
+    path: str
+        Parent directory of the dataset.
+    recording_ids: list(int) | int
+        A (list of) int of recording id(s) to be read (order matters and will
+        overwrite default chronological order, e.g. if recording_ids=[1,0],
+        then the first recording returned by this class will be chronologically
+        later than the second recording. Provide recording_ids in ascending
+        order to preserve chronological order.).
+    target_name: str
+        Can be "pathological", "gender", or "age".
+    preload: bool
+        If True, preload the data of the Raw objects.
+
+    References
+    ----------
+    .. [Khan2022] Khan, H.A., Ul Ain, R., Kamboh, A.M., Butt, H.T.,
+        Shafait, S., Alamgir, W., Stricker, D. and Shafait, F., 2022. The NMT
+        scalp EEG dataset: an open-source annotated dataset of healthy and
+        pathological EEG recordings for predictive modeling. Frontiers in
+        Neuroscience, 15, p.755817.
+    """
+
+    def __init__(
+        self,
+        path=None,
+        target_name="pathological",
+        recording_ids=None,
+        preload=False,
+        n_jobs=1,
+    ):
+        # correct the path if needed
+        if path is not None:
+            list_csv = glob.glob(f"{path}/**/Labels.csv", recursive=True)
+            if isinstance(list_csv, list) and len(list_csv) > 0:
+                path = Path(list_csv[0]).parent
+
+        if path is None or len(list_csv) == 0:
+            path = fetch_dataset(
+                dataset_params=NMT_dataset_params,
+                path=Path(path) if path is not None else None,
+                processor="unzip",
+                force_update=False,
+            )
+            # First time we fetch the dataset, we need to move the files to the
+            # correct directory.
+            path = _correct_path(path)
+
+        # Get all file paths
+        file_paths = glob.glob(
+            os.path.join(path, "**" + os.sep + "*.edf"), recursive=True
+        )
+
+        # keep only .edf files
+        file_paths = [
+            file_path
+            for file_path in file_paths
+            if os.path.splitext(file_path)[1] == ".edf"
+        ]
+
+        # sort by subject id
+        file_paths = sorted(
+            file_paths, key=lambda p: int(os.path.splitext(p)[0].split(os.sep)[-1])
+        )
+        if recording_ids is not None:
+            file_paths = [file_paths[rec_id] for rec_id in recording_ids]
+
+        # read labels and rearrange them to match TUH Abnormal EEG Corpus
+        description = pd.read_csv(
+            os.path.join(path, "Labels.csv"), index_col="recordname"
+        )
+        if recording_ids is not None:
+            description = description.iloc[recording_ids]
+        description.replace(
+            {
+                "not specified": "X",
+                "female": "F",
+                "male": "M",
+                "abnormal": True,
+                "normal": False,
+            },
+            inplace=True,
+        )
+        description.rename(columns={"label": "pathological"}, inplace=True)
+        description.reset_index(drop=True, inplace=True)
+        description["path"] = file_paths
+        description = description[["path", "pathological", "age", "gender"]]
+
+        if n_jobs == 1:
+            base_datasets = [
+                self._create_dataset(d, target_name, preload)
+                for recording_id, d in description.iterrows()
+            ]
+        else:
+            base_datasets = Parallel(n_jobs)(
+                delayed(self._create_dataset)(d, target_name, preload)
+                for recording_id, d in description.iterrows()
+            )
+
+        super().__init__(base_datasets)
+
+    @staticmethod
+    def _create_dataset(d, target_name, preload):
+        raw = mne.io.read_raw_edf(d.path, preload=preload)
+        d["n_samples"] = raw.n_times
+        d["sfreq"] = raw.info["sfreq"]
+        d["train"] = "train" in d.path.split(os.sep)
+        base_dataset = BaseDataset(raw, d, target_name)
+        return base_dataset
+
+
+def _correct_path(path: str):
+    """
+    Check if the path is correct and rename the file if needed.
+
+    Parameters
+    ----------
+    path: basestring
+        Path to the file.
+
+    Returns
+    -------
+    path: basestring
+        Corrected path.
+    """
+    if not Path(path).exists():
+        unzip_file_name = f"{NMT_archive_name}.unzip"
+        if (Path(path).parent / unzip_file_name).exists():
+            try:
+                os.rename(
+                    src=Path(path).parent / unzip_file_name,
+                    dst=Path(path),
+                )
+
+            except PermissionError:
+                raise PermissionError(
+                    f"Please rename {Path(path).parent / unzip_file_name}"
+                    + f" manually to {path} and try again."
+                )
+    path = os.path.join(path, "nmt_scalp_eeg_dataset")
+
+    return path
+
+
+def _get_header(*args):
+    all_paths = {**_NMT_PATHS}
+    return all_paths[args[0]]
+
+
+def _fake_pd_read_csv(*args, **kwargs):
+    # Create a list of lists to hold the data
+    data = [
+        ["0000001.edf", "normal", 35, "male", "train"],
+        ["0000002.edf", "abnormal", 28, "female", "test"],
+        ["0000003.edf", "normal", 62, "male", "train"],
+        ["0000004.edf", "abnormal", 41, "female", "test"],
+        ["0000005.edf", "normal", 19, "male", "train"],
+        ["0000006.edf", "abnormal", 55, "female", "test"],
+        ["0000007.edf", "normal", 71, "male", "train"],
+    ]
+
+    # Create the DataFrame, specifying column names
+    df = pd.DataFrame(data, columns=["recordname", "label", "age", "gender", "loc"])
+
+    return df
+
+
+def _fake_raw(*args, **kwargs):
+    sfreq = 10
+    ch_names = [
+        "EEG A1-REF",
+        "EEG A2-REF",
+        "EEG FP1-REF",
+        "EEG FP2-REF",
+        "EEG F3-REF",
+        "EEG F4-REF",
+        "EEG C3-REF",
+        "EEG C4-REF",
+        "EEG P3-REF",
+        "EEG P4-REF",
+        "EEG O1-REF",
+        "EEG O2-REF",
+        "EEG F7-REF",
+        "EEG F8-REF",
+        "EEG T3-REF",
+        "EEG T4-REF",
+        "EEG T5-REF",
+        "EEG T6-REF",
+        "EEG FZ-REF",
+        "EEG CZ-REF",
+        "EEG PZ-REF",
+    ]
+    duration_min = 6
+    data = np.random.randn(len(ch_names), duration_min * sfreq * 60)
+    info = mne.create_info(ch_names=ch_names, sfreq=sfreq, ch_types="eeg")
+    raw = mne.io.RawArray(data=data, info=info)
+    return raw
+
+
+_NMT_PATHS = {
+    # these are actual file paths and edf headers from NMT EEG Corpus
+    "nmt_scalp_eeg_dataset/abnormal/train/0000036.edf": b"0 0000036 M 13-May-1951 0000036 Age:32 ",  # noqa E501
+    "nmt_scalp_eeg_dataset/abnormal/eval/0000037.edf": b"0 0000037 M 13-May-1951 0000037 Age:32 ",  # noqa E501
+    "nmt_scalp_eeg_dataset/abnormal/eval/0000038.edf": b"0 0000038 M 13-May-1951 0000038 Age:32 ",  # noqa E501
+    "nmt_scalp_eeg_dataset/normal/train/0000039.edf": b"0 0000039 M 13-May-1951 0000039 Age:32 ",  # noqa E501
+    "nmt_scalp_eeg_dataset/normal/eval/0000040.edf": b"0 0000040 M 13-May-1951 0000040 Age:32 ",  # noqa E501
+    "nmt_scalp_eeg_dataset/normal/eval/0000041.edf": b"0 0000041 M 13-May-1951 0000041 Age:32 ",  # noqa E501
+    "nmt_scalp_eeg_dataset/abnormal/train/0000042.edf": b"0 0000042 M 13-May-1951 0000042 Age:32 ",  # noqa E501
+    "Labels.csv": b"0 recordname,label,age,gender,loc 1 0000001.edf,normal,22,not specified,train ",  # noqa E501
+}
+
+
+class _NMTMock(NMT):
+    """Mocked class for testing and examples."""
+
+    @mock.patch("glob.glob", return_value=_NMT_PATHS.keys())
+    @mock.patch("mne.io.read_raw_edf", new=_fake_raw)
+    @mock.patch("pandas.read_csv", new=_fake_pd_read_csv)
+    def __init__(
+        self,
+        mock_glob,
+        path,
+        recording_ids=None,
+        target_name="pathological",
+        preload=False,
+        n_jobs=1,
+    ):
+        with warnings.catch_warnings():
+            warnings.filterwarnings("ignore", message="Cannot save date file")
+            super().__init__(
+                path=path,
+                recording_ids=recording_ids,
+                target_name=target_name,
+                preload=preload,
+                n_jobs=n_jobs,
+            )
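
The new NMT class follows the same BaseConcatDataset pattern as the TUH loaders. A minimal instantiation sketch based on the constructor above (the recording ids are illustrative; with path=None the corpus is first fetched from Zenodo via mne.datasets.fetch_dataset):

from braindecode.datasets.nmt import NMT

# Pass a directory that already contains Labels.csv to reuse a local copy;
# otherwise the dataset is downloaded and unpacked automatically.
ds = NMT(
    path=None,
    target_name="pathological",  # alternatively "gender" or "age"
    recording_ids=[0, 1],  # small subset for a quick look; omit to load all
    preload=False,
    n_jobs=1,
)
print(ds.description[["path", "pathological", "age", "gender"]])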