braindecode 0.8__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of braindecode might be problematic. Click here for more details.

Files changed (102) hide show
  1. braindecode/__init__.py +1 -2
  2. braindecode/augmentation/__init__.py +50 -0
  3. braindecode/augmentation/base.py +222 -0
  4. braindecode/augmentation/functional.py +1096 -0
  5. braindecode/augmentation/transforms.py +1274 -0
  6. braindecode/classifier.py +26 -24
  7. braindecode/datasets/__init__.py +34 -0
  8. braindecode/datasets/base.py +840 -0
  9. braindecode/datasets/bbci.py +694 -0
  10. braindecode/datasets/bcicomp.py +194 -0
  11. braindecode/datasets/bids.py +245 -0
  12. braindecode/datasets/mne.py +172 -0
  13. braindecode/datasets/moabb.py +209 -0
  14. braindecode/datasets/nmt.py +311 -0
  15. braindecode/datasets/sleep_physio_challe_18.py +412 -0
  16. braindecode/datasets/sleep_physionet.py +125 -0
  17. braindecode/datasets/tuh.py +588 -0
  18. braindecode/datasets/xy.py +95 -0
  19. braindecode/datautil/__init__.py +49 -0
  20. braindecode/datautil/serialization.py +342 -0
  21. braindecode/datautil/util.py +41 -0
  22. braindecode/eegneuralnet.py +63 -47
  23. braindecode/functional/__init__.py +10 -0
  24. braindecode/functional/functions.py +251 -0
  25. braindecode/functional/initialization.py +47 -0
  26. braindecode/models/__init__.py +52 -0
  27. braindecode/models/atcnet.py +652 -0
  28. braindecode/models/attentionbasenet.py +550 -0
  29. braindecode/models/base.py +296 -0
  30. braindecode/models/biot.py +483 -0
  31. braindecode/models/contrawr.py +296 -0
  32. braindecode/models/ctnet.py +450 -0
  33. braindecode/models/deep4.py +322 -0
  34. braindecode/models/deepsleepnet.py +295 -0
  35. braindecode/models/eegconformer.py +372 -0
  36. braindecode/models/eeginception_erp.py +304 -0
  37. braindecode/models/eeginception_mi.py +371 -0
  38. braindecode/models/eegitnet.py +301 -0
  39. braindecode/models/eegminer.py +255 -0
  40. braindecode/models/eegnet.py +473 -0
  41. braindecode/models/eegnex.py +247 -0
  42. braindecode/models/eegresnet.py +362 -0
  43. braindecode/models/eegsimpleconv.py +199 -0
  44. braindecode/models/eegtcnet.py +335 -0
  45. braindecode/models/fbcnet.py +221 -0
  46. braindecode/models/fblightconvnet.py +313 -0
  47. braindecode/models/fbmsnet.py +325 -0
  48. braindecode/models/hybrid.py +126 -0
  49. braindecode/models/ifnet.py +441 -0
  50. braindecode/models/labram.py +1166 -0
  51. braindecode/models/msvtnet.py +375 -0
  52. braindecode/models/sccnet.py +182 -0
  53. braindecode/models/shallow_fbcsp.py +208 -0
  54. braindecode/models/signal_jepa.py +1012 -0
  55. braindecode/models/sinc_shallow.py +337 -0
  56. braindecode/models/sleep_stager_blanco_2020.py +167 -0
  57. braindecode/models/sleep_stager_chambon_2018.py +157 -0
  58. braindecode/models/sleep_stager_eldele_2021.py +536 -0
  59. braindecode/models/sparcnet.py +378 -0
  60. braindecode/models/summary.csv +41 -0
  61. braindecode/models/syncnet.py +232 -0
  62. braindecode/models/tcn.py +273 -0
  63. braindecode/models/tidnet.py +395 -0
  64. braindecode/models/tsinception.py +258 -0
  65. braindecode/models/usleep.py +340 -0
  66. braindecode/models/util.py +133 -0
  67. braindecode/modules/__init__.py +38 -0
  68. braindecode/modules/activation.py +60 -0
  69. braindecode/modules/attention.py +757 -0
  70. braindecode/modules/blocks.py +108 -0
  71. braindecode/modules/convolution.py +274 -0
  72. braindecode/modules/filter.py +632 -0
  73. braindecode/modules/layers.py +133 -0
  74. braindecode/modules/linear.py +50 -0
  75. braindecode/modules/parametrization.py +38 -0
  76. braindecode/modules/stats.py +77 -0
  77. braindecode/modules/util.py +77 -0
  78. braindecode/modules/wrapper.py +75 -0
  79. braindecode/preprocessing/__init__.py +37 -0
  80. braindecode/preprocessing/mne_preprocess.py +77 -0
  81. braindecode/preprocessing/preprocess.py +478 -0
  82. braindecode/preprocessing/windowers.py +1031 -0
  83. braindecode/regressor.py +23 -12
  84. braindecode/samplers/__init__.py +18 -0
  85. braindecode/samplers/base.py +401 -0
  86. braindecode/samplers/ssl.py +263 -0
  87. braindecode/training/__init__.py +23 -0
  88. braindecode/training/callbacks.py +23 -0
  89. braindecode/training/losses.py +105 -0
  90. braindecode/training/scoring.py +483 -0
  91. braindecode/util.py +55 -59
  92. braindecode/version.py +1 -1
  93. braindecode/visualization/__init__.py +8 -0
  94. braindecode/visualization/confusion_matrices.py +289 -0
  95. braindecode/visualization/gradients.py +57 -0
  96. {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info}/METADATA +39 -55
  97. braindecode-1.0.0.dist-info/RECORD +101 -0
  98. {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info}/WHEEL +1 -1
  99. {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info/licenses}/LICENSE.txt +1 -1
  100. braindecode-1.0.0.dist-info/licenses/NOTICE.txt +20 -0
  101. braindecode-0.8.dist-info/RECORD +0 -11
  102. {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,194 @@
1
+ # Authors: Maciej Sliwowski <maciek.sliwowski@gmail.com>
2
+ # Mohammed Fattouh <mo.fattouh@gmail.com>
3
+ #
4
+ # License: BSD (3-clause)
5
+
6
+ from __future__ import annotations
7
+
8
+ import glob
9
+ import os
10
+ import os.path as osp
11
+ from os import remove
12
+ from shutil import unpack_archive
13
+
14
+ import mne
15
+ import numpy as np
16
+ from mne.utils import verbose
17
+ from scipy.io import loadmat
18
+
19
+ from braindecode.datasets import BaseConcatDataset, BaseDataset
20
+
21
# Zip archive with the BCI Competition IV dataset 4 (finger flexion) files.
# NOTE: "Competion" is intentionally misspelled: it matches the archive name
# and the folder name the archive unpacks to (see ``download``).
DATASET_URL = (
    "https://stacks.stanford.edu/file/"
    "druid:zk881ps0522/BCI_Competion4_dataset4_data_fingerflexions.zip"
)
25
+
26
+
27
class BCICompetitionIVDataset4(BaseConcatDataset):
    """BCI competition IV dataset 4.

    Contains ECoG recordings for three patients moving fingers during the experiment.
    Targets correspond to the time courses of the flexion of each of five fingers.
    See http://www.bbci.de/competition/iv/desc_4.pdf and
    http://www.bbci.de/competition/iv/ for the dataset and competition description.
    ECoG library containing the dataset: https://searchworks.stanford.edu/view/zk881ps0522

    Each requested subject contributes two recordings, a ``"train"`` and a
    ``"test"`` session (in that order). The finger-flexion targets are stored
    as extra ``misc`` channels named ``target_<i>`` after the ECoG channels.

    Notes
    -----
    When using this dataset please cite [1]_ .

    Parameters
    ----------
    subject_ids : list(int) | int | None
        (list of) int of subject(s) to be loaded. If None, load all available
        subjects. Should be in range 1-3.

    Raises
    ------
    ValueError
        If ``subject_ids`` has an unsupported type or contains ids that are not
        in ``possible_subjects``. This is checked before anything is downloaded.

    References
    ----------
    .. [1] Miller, Kai J. "A library of human electrocorticographic data and analyses."
       Nature human behaviour 3, no. 11 (2019): 1225-1235.
       https://doi.org/10.1038/s41562-019-0678-3
    """

    possible_subjects = [1, 2, 3]

    def __init__(self, subject_ids: list[int] | int | None = None):
        if isinstance(subject_ids, int):
            subject_ids = [subject_ids]
        if subject_ids is None:
            subject_ids = self.possible_subjects
        # Validate before downloading so an invalid ``subject_ids`` fails
        # immediately instead of after fetching the whole archive.
        self._validate_subjects(subject_ids)
        data_path = self.download()
        datasets = []
        for subject_id in subject_ids:
            file_name = f"sub{subject_id}_comp.mat"
            raw_train, raw_test = self._load_data_to_mne(f"{data_path}/{file_name}")
            for raw, session in ((raw_train, "train"), (raw_test, "test")):
                description = dict(
                    subject=str(subject_id),
                    file_name=file_name,
                    session=session,
                )
                datasets.append(BaseDataset(raw, description=description))
        super().__init__(datasets)

    @staticmethod
    def download(path=None, force_update=False, verbose=None):
        """Download the dataset.

        Parameters
        ----------
        path : None | str
            Location of where to look for the data storing location.
            If None, the environment variable or config parameter
            MNE_DATASETS_(dataset)_PATH is used. If it doesn't exist, the "~/mne_data"
            directory is used. If the dataset is not found under the given path, the data
            will be automatically downloaded to the specified folder.
        force_update : bool
            Force update of the dataset even if a local copy exists.
        verbose : bool, str, int, or None
            If not None, override default verbose level (see mne.verbose()).
            Forwarded to the internal download helper.

        Returns
        -------
        destination : str
            Folder containing the unpacked ``.mat`` files.
        """
        signature = "BCICompetitionIVDataset4"
        folder_name = "BCI_Competion4_dataset4_data_fingerflexions"

        from moabb.datasets.download import get_dataset_path  # keep soft dependency

        path = get_dataset_path(signature, path)
        key_dest = "MNE-{:s}-data".format(signature.lower())
        # We do not use mne _url_to_local_path due to ':' in the url that causes
        # problems on Windows.
        destination = osp.join(path, key_dest, folder_name)
        # The .zip is deleted after unpacking, so "already downloaded" has to be
        # detected from the unpacked files: 3 subjects x (``*_comp.mat`` +
        # ``*_testlabels.mat``) = 6 files. ``force_update`` bypasses this check.
        n_expected_files = 6
        n_found = len(glob.glob(osp.join(destination, "*.mat")))
        if not force_update and n_found == n_expected_files:
            return destination
        # Use the exact file name pooch writes (the URL basename) so that the
        # existence check / forced removal in ``_data_dl`` act on the real archive.
        archive_path = osp.join(destination, folder_name, osp.basename(DATASET_URL))
        data_path = _data_dl(
            DATASET_URL,
            archive_path,
            force_update=force_update,
            verbose=verbose,
        )
        unpack_archive(data_path, osp.dirname(destination))
        # removes .zip file that the data was unpacked from
        remove(data_path)
        return destination

    @staticmethod
    def _prepare_targets(upsampled_targets, targets_stride):
        """Keep every ``targets_stride``-th target row and set all others to NaN."""
        original_targets = np.full_like(upsampled_targets, np.nan)
        original_targets[::targets_stride] = upsampled_targets[::targets_stride]
        return original_targets

    def _load_data_to_mne(self, file_path):
        """Load one subject's ``*_comp.mat`` (+ matching ``*_testlabels.mat``).

        Returns
        -------
        raw_train, raw_test : mne.io.RawArray
            Recordings whose channels are the ECoG signals (type ``ecog``)
            followed by the targets (type ``misc``, named ``target_<i>``).
        """
        data = loadmat(file_path)
        test_labels = loadmat(file_path.replace("comp.mat", "testlabels.mat"))
        train_data = data["train_data"]
        test_data = data["test_data"]
        upsampled_train_targets = data["train_dg"]
        upsampled_test_targets = test_labels["test_dg"]

        signal_sfreq = 1000
        original_target_sfreq = 25
        # The variable names indicate the targets were upsampled from 25 Hz to
        # the 1000 Hz signal rate; keep one target sample every 40 samples.
        targets_stride = int(signal_sfreq / original_target_sfreq)

        original_targets = self._prepare_targets(
            upsampled_train_targets, targets_stride
        )
        original_test_targets = self._prepare_targets(
            upsampled_test_targets, targets_stride
        )

        n_signal_chs = train_data.shape[1]
        n_target_chs = original_targets.shape[1]
        ch_names = [f"{i}" for i in range(n_signal_chs)]
        ch_names += [f"target_{i}" for i in range(n_target_chs)]
        ch_types = ["ecog"] * n_signal_chs + ["misc"] * n_target_chs

        info = mne.create_info(sfreq=signal_sfreq, ch_names=ch_names, ch_types=ch_types)
        info["temp"] = dict(target_sfreq=original_target_sfreq)
        train_data = np.concatenate([train_data, original_targets], axis=1)
        test_data = np.concatenate([test_data, original_test_targets], axis=1)

        # The .mat arrays are (n_times, n_channels); RawArray expects channels first.
        raw_train = mne.io.RawArray(train_data.T, info=info)
        raw_test = mne.io.RawArray(test_data.T, info=info)
        # TODO: show how to resample targets
        return raw_train, raw_test

    def _validate_subjects(self, subject_ids):
        """Raise ``ValueError`` unless ``subject_ids`` is a list/tuple of known ids."""
        if not isinstance(subject_ids, (list, tuple)):
            raise ValueError(
                "Wrong subject_ids format. Expected types: None, list, tuple, int."
            )
        if not all(subject in self.possible_subjects for subject in subject_ids):
            raise ValueError(
                f"Wrong subject_ids parameter. Possible values: {self.possible_subjects}. "
                f"Provided {subject_ids}."
            )
173
+
174
+
175
@verbose
def _data_dl(url, destination, force_update=False, verbose=None):
    """Download ``url`` with pooch into the parent folder of ``destination``.

    Code taken from moabb (moabb/datasets/download.py) because of a problem with
    ':' occurring in the path: on Windows ':' is forbidden in folder names.

    Parameters
    ----------
    url : str
        Remote file to fetch.
    destination : str
        Local path of the file. Its parent folder is where pooch stores the
        download, under the file name ``osp.basename(url)``.
    force_update : bool
        If True, delete ``destination`` (when it exists) and download again.
    verbose : bool | str | int | None
        Handled by the mne ``verbose`` decorator.

    Returns
    -------
    data_path : str
        Local path returned by ``pooch.retrieve``.
    """
    from pooch import file_hash, retrieve  # keep soft dependency

    if not osp.isfile(destination) or force_update:
        if osp.isfile(destination):
            os.remove(destination)
        # exist_ok avoids the check-then-create race of isdir() + makedirs().
        os.makedirs(osp.dirname(destination), exist_ok=True)
        known_hash = None
    else:
        # Hash the existing file so pooch can recognise it as up to date
        # (only effective when destination is the file pooch itself writes).
        known_hash = file_hash(destination)
    data_path = retrieve(
        url, known_hash, fname=osp.basename(url), path=osp.dirname(destination)
    )
    return data_path
@@ -0,0 +1,245 @@
1
+ """Dataset for loading BIDS.
2
+
3
+ More information on BIDS (Brain Imaging Data Structure) can be found at https://bids.neuroimaging.io
4
+ """
5
+
6
+ # Authors: Pierre Guetschel <pierre.guetschel@gmail.com>
7
+ #
8
+ # License: BSD (3-clause)
9
+
10
+ from __future__ import annotations
11
+
12
+ from dataclasses import dataclass, field
13
+ from pathlib import Path
14
+ from typing import Any
15
+
16
+ import mne
17
+ import mne_bids
18
+ import numpy as np
19
+ import pandas as pd
20
+ from joblib import Parallel, delayed
21
+
22
+ from .base import BaseConcatDataset, BaseDataset, WindowsDataset
23
+
24
+
25
+ def _description_from_bids_path(bids_path: mne_bids.BIDSPath) -> dict[str, Any]:
26
+ return {
27
+ "path": bids_path.fpath,
28
+ "subject": bids_path.subject,
29
+ "session": bids_path.session,
30
+ "task": bids_path.task,
31
+ "acquisition": bids_path.acquisition,
32
+ "run": bids_path.run,
33
+ "processing": bids_path.processing,
34
+ "recording": bids_path.recording,
35
+ "space": bids_path.space,
36
+ "split": bids_path.split,
37
+ "description": bids_path.description,
38
+ "suffix": bids_path.suffix,
39
+ "extension": bids_path.extension,
40
+ "datatype": bids_path.datatype,
41
+ }
42
+
43
+
44
@dataclass
class BIDSDataset(BaseConcatDataset):
    """Dataset for loading BIDS.

    This class has the same parameters as the :func:`mne_bids.find_matching_paths` function
    as it will be used to find the files to load. The default ``extensions`` parameter was changed.

    More information on BIDS (Brain Imaging Data Structure)
    can be found at https://bids.neuroimaging.io

    ``.json`` sidecar files are always skipped, and so are ``_epo.fif`` files.

    .. Note::
        For loading "unofficial" BIDS datasets containing epoched data,
        you can use :class:`BIDSEpochsDataset`.

    Parameters
    ----------
    root : pathlib.Path | str
        The root of the BIDS path.
    subjects : str | array-like of str | None
        The subject ID. Corresponds to "sub".
    sessions : str | array-like of str | None
        The acquisition session. Corresponds to "ses".
    tasks : str | array-like of str | None
        The experimental task. Corresponds to "task".
    acquisitions: str | array-like of str | None
        The acquisition parameters. Corresponds to "acq".
    runs : str | array-like of str | None
        The run number. Corresponds to "run".
    processings : str | array-like of str | None
        The processing label. Corresponds to "proc".
    recordings : str | array-like of str | None
        The recording name. Corresponds to "rec".
    spaces : str | array-like of str | None
        The coordinate space for anatomical and sensor location
        files (e.g., ``*_electrodes.tsv``, ``*_markers.mrk``).
        Corresponds to "space".
        Note that valid values for ``space`` must come from a list
        of BIDS keywords as described in the BIDS specification.
    splits : str | array-like of str | None
        The split of the continuous recording file for ``.fif`` data.
        Corresponds to "split".
    descriptions : str | array-like of str | None
        This corresponds to the BIDS entity ``desc``. It is used to provide
        additional information for derivative data, e.g., preprocessed data
        may be assigned ``description='cleaned'``.
    suffixes : str | array-like of str | None
        The filename suffix. This is the entity after the
        last ``_`` before the extension. E.g., ``'channels'``.
        The following filename suffixes are accepted:
        'meg', 'markers', 'eeg', 'ieeg', 'T1w',
        'participants', 'scans', 'electrodes', 'coordsystem',
        'channels', 'events', 'headshape', 'digitizer',
        'beh', 'physio', 'stim'
    extensions : str | array-like of str | None
        The extension of the filename. E.g., ``'.json'``.
        By default, uses the ones accepted by :func:`mne_bids.read_raw_bids`.
    datatypes : str | array-like of str | None
        The BIDS data type, e.g., ``'anat'``, ``'func'``, ``'eeg'``, ``'meg'``,
        ``'ieeg'``.
    check : bool
        If ``True``, only returns paths that conform to BIDS. If ``False``
        (default), the ``.check`` attribute of the returned
        :class:`mne_bids.BIDSPath` object will be set to ``True`` for paths that
        do conform to BIDS, and to ``False`` for those that don't.
    preload : bool
        If True, preload the data. Defaults to False.
    n_jobs : int
        Number of jobs to run in parallel. Defaults to 1.
    """

    root: Path | str
    subjects: str | list[str] | None = None
    sessions: str | list[str] | None = None
    tasks: str | list[str] | None = None
    acquisitions: str | list[str] | None = None
    runs: str | list[str] | None = None
    processings: str | list[str] | None = None
    recordings: str | list[str] | None = None
    spaces: str | list[str] | None = None
    splits: str | list[str] | None = None
    descriptions: str | list[str] | None = None
    suffixes: str | list[str] | None = None
    extensions: str | list[str] | None = field(
        default_factory=lambda: [
            ".con",
            ".sqd",
            ".pdf",
            ".fif",
            ".ds",
            ".vhdr",
            ".set",
            ".edf",
            ".bdf",
            ".EDF",
            ".snirf",
            ".cdt",
            ".mef",
            ".nwb",
        ]
    )
    datatypes: str | list[str] | None = None
    check: bool = False
    preload: bool = False
    n_jobs: int = 1

    @property
    def _filter_out_epochs(self):
        # Subclasses that load ``_epo.fif`` files override this with False.
        return True

    def __post_init__(self):
        candidates = mne_bids.find_matching_paths(
            root=self.root,
            subjects=self.subjects,
            sessions=self.sessions,
            tasks=self.tasks,
            acquisitions=self.acquisitions,
            runs=self.runs,
            processings=self.processings,
            recordings=self.recordings,
            spaces=self.spaces,
            splits=self.splits,
            descriptions=self.descriptions,
            suffixes=self.suffixes,
            extensions=self.extensions,
            datatypes=self.datatypes,
            check=self.check,
        )
        skip_epochs = self._filter_out_epochs

        def _is_wanted(candidate):
            # .json sidecars are dropped manually because the ``ignore_json``
            # argument only exists in mne-bids>=0.16.
            if candidate.extension == ".json":
                return False
            is_epochs_file = candidate.suffix == "epo" and candidate.extension == ".fif"
            return not (skip_epochs and is_epochs_file)

        selected = [candidate for candidate in candidates if _is_wanted(candidate)]
        base_datasets = Parallel(n_jobs=self.n_jobs)(
            delayed(self._get_dataset)(candidate) for candidate in selected
        )
        super().__init__(base_datasets)

    def _get_dataset(self, bids_path: mne_bids.BIDSPath) -> BaseDataset:
        """Read one recording with ``mne_bids.read_raw_bids`` and wrap it."""
        description = _description_from_bids_path(bids_path)
        raw = mne_bids.read_raw_bids(bids_path, verbose=False)
        if self.preload:
            raw.load_data()
        return BaseDataset(raw, description)
195
+
196
+
197
class BIDSEpochsDataset(BIDSDataset):
    """**Experimental** dataset for loading :class:`mne.Epochs` organised in BIDS.

    The files must end with ``_epo.fif``.

    .. Warning::
        Epoched data is not officially supported in BIDS.

    .. Note::
        **Parameters:** This class has the same parameters as :class:`BIDSDataset` except
        for arguments ``suffixes``, ``extensions`` and ``check`` which are fixed
        (to ``"epo"``, ``".fif"`` and ``False`` respectively).
    """

    @property
    def _filter_out_epochs(self):
        # ``_epo.fif`` files are exactly what this class loads, so keep them.
        return False

    def __init__(self, *args, **kwargs):
        super().__init__(
            *args,
            extensions=".fif",
            suffixes="epo",
            check=False,
            **kwargs,
        )

    def _set_metadata(self, epochs: mne.BaseEpochs) -> None:
        """Attach window-style metadata to ``epochs`` in place.

        Every epoch is described as a single window covering the whole epoch:
        ``i_window_in_trial`` and ``i_start_in_trial`` are 0 and
        ``i_stop_in_trial`` is the number of samples per epoch.

        ``target`` is taken from the annotation descriptions when there is
        exactly one annotation per epoch. Otherwise (no annotations, or a
        different count) it is obtained by mapping the event codes in
        ``epochs.events[:, -1]`` back to their names in ``epochs.event_id``.
        """
        n_epochs = len(epochs)
        n_times = epochs.times.shape[0]
        annotations = epochs.annotations
        # NOTE(review): assumes annotations are ordered like the epochs when
        # their counts match — confirm for your files.
        if annotations is not None and len(annotations) == n_epochs:
            target = annotations.description
        else:
            # Falling back here also avoids a DataFrame length mismatch when
            # ``annotations`` is an empty or partial Annotations object.
            id_to_name = {v: k for k, v in epochs.event_id.items()}
            target = [id_to_name[event_id] for event_id in epochs.events[:, -1]]
        metadata_dict = {
            "i_window_in_trial": np.zeros(n_epochs),
            "i_start_in_trial": np.zeros(n_epochs),
            "i_stop_in_trial": np.zeros(n_epochs) + n_times,
            "target": target,
        }
        epochs.metadata = pd.DataFrame(metadata_dict)

    def _get_dataset(self, bids_path):
        """Read one ``_epo.fif`` file and wrap it as a ``WindowsDataset``."""
        description = _description_from_bids_path(bids_path)
        epochs = mne.read_epochs(bids_path.fpath)
        self._set_metadata(epochs)
        return WindowsDataset(epochs, description=description, targets_from="metadata")
@@ -0,0 +1,172 @@
1
+ # Authors: Lukas Gemein <l.gemein@gmail.com>
2
+ # Robin Schirrmeister <robintibor@gmail.com>
3
+ #
4
+ # License: BSD (3-clause)
5
+
6
+ from __future__ import annotations
7
+
8
+ import mne
9
+ import numpy as np
10
+ import pandas as pd
11
+
12
+ from .base import BaseConcatDataset, BaseDataset, WindowsDataset
13
+
14
+
15
def create_from_mne_raw(
    raws: list[mne.io.BaseRaw],
    trial_start_offset_samples: int,
    trial_stop_offset_samples: int,
    window_size_samples: int,
    window_stride_samples: int,
    drop_last_window: bool,
    descriptions: list[dict | pd.Series] | None = None,
    mapping: dict[str, int] | None = None,
    preload: bool = False,
    drop_bad_windows: bool = True,
    accepted_bads_ratio: float = 0.0,
) -> BaseConcatDataset:
    """Create WindowsDatasets from mne.RawArrays

    Each raw is wrapped in a ``BaseDataset`` (with its matching description,
    if given) and the resulting concatenation is windowed with
    ``create_windows_from_events``.

    Parameters
    ----------
    raws: array-like
        list of mne.RawArrays
    trial_start_offset_samples: int
        start offset from original trial onsets in samples
    trial_stop_offset_samples: int
        stop offset from original trial stop in samples
    window_size_samples: int
        window size
    window_stride_samples: int
        stride between windows
    drop_last_window: bool
        whether or not have a last overlapping window, when
        windows do not equally divide the continuous signal
    descriptions: array-like
        list of dicts or pandas.Series with additional information about the raws
    mapping: dict(str: int)
        mapping from event description to target value
    preload: bool
        if True, preload the data of the Epochs objects.
    drop_bad_windows: bool
        If True, call `.drop_bad()` on the resulting mne.Epochs object. This
        step allows identifying e.g., windows that fall outside of the
        continuous recording. It is suggested to run this step here as otherwise
        the BaseConcatDataset has to be updated as well.
    accepted_bads_ratio: float, optional
        Acceptable proportion of trials with inconsistent length in a raw. If
        the number of trials whose length is exceeded by the window size is
        smaller than this, then only the corresponding trials are dropped, but
        the computation continues. Otherwise, an error is raised. Defaults to
        0.0 (raise an error).

    Returns
    -------
    windows_datasets: BaseConcatDataset
        X and y transformed to a dataset format that is compatible with skorch
        and braindecode

    Raises
    ------
    ValueError
        If ``descriptions`` is given and its length differs from ``raws``.
    """
    # Imported here to prevent a circular import.
    from ..preprocessing.windowers import create_windows_from_events

    if descriptions is None:
        base_datasets = [BaseDataset(raw) for raw in raws]
    else:
        if len(descriptions) != len(raws):
            raise ValueError(
                f"length of 'raws' ({len(raws)}) and 'description' "
                f"({len(descriptions)}) has to match"
            )
        base_datasets = [
            BaseDataset(raw, desc) for raw, desc in zip(raws, descriptions)
        ]

    return create_windows_from_events(
        BaseConcatDataset(base_datasets),
        trial_start_offset_samples=trial_start_offset_samples,
        trial_stop_offset_samples=trial_stop_offset_samples,
        window_size_samples=window_size_samples,
        window_stride_samples=window_stride_samples,
        drop_last_window=drop_last_window,
        mapping=mapping,
        drop_bad_windows=drop_bad_windows,
        preload=preload,
        accepted_bads_ratio=accepted_bads_ratio,
    )
98
+
99
+
100
def create_from_mne_epochs(
    list_of_epochs: list[mne.BaseEpochs],
    window_size_samples: int,
    window_stride_samples: int,
    drop_last_window: bool,
) -> BaseConcatDataset:
    """Create WindowsDatasets from mne.Epochs

    Every trial of every ``mne.Epochs`` object is cut into windows of
    ``window_size_samples`` samples, ``window_stride_samples`` apart, and each
    trial becomes one ``WindowsDataset``.

    Parameters
    ----------
    list_of_epochs: array-like
        list of mne.Epochs
    window_size_samples: int
        window size
    window_stride_samples: int
        stride between windows
    drop_last_window: bool
        whether or not have a last overlapping window, when
        windows do not equally divide the continuous signal

    Returns
    -------
    windows_datasets: BaseConcatDataset
        X and y transformed to a dataset format that is compatible with skorch
        and braindecode

    Raises
    ------
    ValueError
        If ``window_size_samples`` is larger than the number of samples per
        trial of one of the ``mne.Epochs`` objects (no window would fit).
    """
    # Prevent circular import
    from ..preprocessing.windowers import _check_windowing_arguments

    _check_windowing_arguments(0, 0, window_size_samples, window_stride_samples)

    list_of_windows_ds = []
    for epochs in list_of_epochs:
        n_times = len(epochs.times)
        # Without this check ``starts`` below would be empty, leading to an
        # IndexError on ``starts[-1]`` or an MNE error about empty events.
        if window_size_samples > n_times:
            raise ValueError(
                f"window_size_samples ({window_size_samples}) is larger than the "
                f"number of samples per trial ({n_times})."
            )
        event_descriptions = epochs.events[:, 2]
        original_trial_starts = epochs.events[:, 0]
        stop = n_times - window_size_samples

        # already includes last incomplete window start
        starts = np.arange(0, stop + 1, window_stride_samples)

        if not drop_last_window and starts[-1] < stop:
            # if last window does not end at trial stop, make it stop there
            starts = np.append(starts, stop)

        fake_events = [[start, window_size_samples, -1] for start in starts]
        # window size - 1, since tmax is inclusive
        tmax = (window_size_samples - 1) / epochs.info["sfreq"]

        # NOTE(review): assumes iterating ``epochs`` yields exactly one array per
        # row of ``epochs.events``; if iteration skips not-yet-dropped bad epochs,
        # ``trial_i`` could misalign with the events — confirm.
        for trial_i, trial in enumerate(epochs):
            window_starts = starts + original_trial_starts[trial_i]
            metadata = pd.DataFrame(
                {
                    "i_window_in_trial": np.arange(len(fake_events)),
                    "i_start_in_trial": window_starts,
                    "i_stop_in_trial": window_starts + window_size_samples,
                    "target": len(fake_events) * [event_descriptions[trial_i]],
                }
            )
            mne_epochs = mne.Epochs(
                mne.io.RawArray(trial, epochs.info),
                fake_events,
                baseline=None,
                tmin=0,
                tmax=tmax,
                metadata=metadata,
            )

            mne_epochs.drop_bad(reject=None, flat=None)

            windows_ds = WindowsDataset(mne_epochs)
            list_of_windows_ds.append(windows_ds)

    return BaseConcatDataset(list_of_windows_ds)