braindecode 1.3.0.dev177069446__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- braindecode/__init__.py +9 -0
- braindecode/augmentation/__init__.py +52 -0
- braindecode/augmentation/base.py +225 -0
- braindecode/augmentation/functional.py +1300 -0
- braindecode/augmentation/transforms.py +1356 -0
- braindecode/classifier.py +258 -0
- braindecode/datasets/__init__.py +44 -0
- braindecode/datasets/base.py +823 -0
- braindecode/datasets/bbci.py +693 -0
- braindecode/datasets/bcicomp.py +193 -0
- braindecode/datasets/bids/__init__.py +54 -0
- braindecode/datasets/bids/datasets.py +239 -0
- braindecode/datasets/bids/format.py +717 -0
- braindecode/datasets/bids/hub.py +987 -0
- braindecode/datasets/bids/hub_format.py +717 -0
- braindecode/datasets/bids/hub_io.py +197 -0
- braindecode/datasets/bids/hub_validation.py +114 -0
- braindecode/datasets/bids/iterable.py +220 -0
- braindecode/datasets/chb_mit.py +163 -0
- braindecode/datasets/mne.py +170 -0
- braindecode/datasets/moabb.py +219 -0
- braindecode/datasets/nmt.py +313 -0
- braindecode/datasets/registry.py +120 -0
- braindecode/datasets/siena.py +162 -0
- braindecode/datasets/sleep_physio_challe_18.py +411 -0
- braindecode/datasets/sleep_physionet.py +125 -0
- braindecode/datasets/tuh.py +591 -0
- braindecode/datasets/utils.py +67 -0
- braindecode/datasets/xy.py +96 -0
- braindecode/datautil/__init__.py +62 -0
- braindecode/datautil/channel_utils.py +114 -0
- braindecode/datautil/hub_formats.py +180 -0
- braindecode/datautil/serialization.py +359 -0
- braindecode/datautil/util.py +154 -0
- braindecode/eegneuralnet.py +372 -0
- braindecode/functional/__init__.py +22 -0
- braindecode/functional/functions.py +251 -0
- braindecode/functional/initialization.py +47 -0
- braindecode/models/__init__.py +117 -0
- braindecode/models/atcnet.py +830 -0
- braindecode/models/attentionbasenet.py +727 -0
- braindecode/models/attn_sleep.py +549 -0
- braindecode/models/base.py +574 -0
- braindecode/models/bendr.py +493 -0
- braindecode/models/biot.py +537 -0
- braindecode/models/brainmodule.py +845 -0
- braindecode/models/config.py +233 -0
- braindecode/models/contrawr.py +319 -0
- braindecode/models/ctnet.py +541 -0
- braindecode/models/deep4.py +376 -0
- braindecode/models/deepsleepnet.py +417 -0
- braindecode/models/eegconformer.py +475 -0
- braindecode/models/eeginception_erp.py +379 -0
- braindecode/models/eeginception_mi.py +379 -0
- braindecode/models/eegitnet.py +302 -0
- braindecode/models/eegminer.py +256 -0
- braindecode/models/eegnet.py +359 -0
- braindecode/models/eegnex.py +354 -0
- braindecode/models/eegsimpleconv.py +201 -0
- braindecode/models/eegsym.py +917 -0
- braindecode/models/eegtcnet.py +337 -0
- braindecode/models/fbcnet.py +225 -0
- braindecode/models/fblightconvnet.py +315 -0
- braindecode/models/fbmsnet.py +338 -0
- braindecode/models/hybrid.py +126 -0
- braindecode/models/ifnet.py +443 -0
- braindecode/models/labram.py +1316 -0
- braindecode/models/luna.py +891 -0
- braindecode/models/medformer.py +760 -0
- braindecode/models/msvtnet.py +377 -0
- braindecode/models/patchedtransformer.py +640 -0
- braindecode/models/reve.py +843 -0
- braindecode/models/sccnet.py +280 -0
- braindecode/models/shallow_fbcsp.py +212 -0
- braindecode/models/signal_jepa.py +1122 -0
- braindecode/models/sinc_shallow.py +339 -0
- braindecode/models/sleep_stager_blanco_2020.py +169 -0
- braindecode/models/sleep_stager_chambon_2018.py +159 -0
- braindecode/models/sparcnet.py +426 -0
- braindecode/models/sstdpn.py +869 -0
- braindecode/models/summary.csv +47 -0
- braindecode/models/syncnet.py +234 -0
- braindecode/models/tcn.py +275 -0
- braindecode/models/tidnet.py +397 -0
- braindecode/models/tsinception.py +295 -0
- braindecode/models/usleep.py +439 -0
- braindecode/models/util.py +369 -0
- braindecode/modules/__init__.py +92 -0
- braindecode/modules/activation.py +86 -0
- braindecode/modules/attention.py +883 -0
- braindecode/modules/blocks.py +160 -0
- braindecode/modules/convolution.py +330 -0
- braindecode/modules/filter.py +654 -0
- braindecode/modules/layers.py +216 -0
- braindecode/modules/linear.py +70 -0
- braindecode/modules/parametrization.py +38 -0
- braindecode/modules/stats.py +87 -0
- braindecode/modules/util.py +85 -0
- braindecode/modules/wrapper.py +90 -0
- braindecode/preprocessing/__init__.py +271 -0
- braindecode/preprocessing/eegprep_preprocess.py +1317 -0
- braindecode/preprocessing/mne_preprocess.py +240 -0
- braindecode/preprocessing/preprocess.py +579 -0
- braindecode/preprocessing/util.py +177 -0
- braindecode/preprocessing/windowers.py +1037 -0
- braindecode/regressor.py +234 -0
- braindecode/samplers/__init__.py +18 -0
- braindecode/samplers/base.py +399 -0
- braindecode/samplers/ssl.py +263 -0
- braindecode/training/__init__.py +23 -0
- braindecode/training/callbacks.py +23 -0
- braindecode/training/losses.py +105 -0
- braindecode/training/scoring.py +477 -0
- braindecode/util.py +419 -0
- braindecode/version.py +1 -0
- braindecode/visualization/__init__.py +8 -0
- braindecode/visualization/confusion_matrices.py +289 -0
- braindecode/visualization/gradients.py +62 -0
- braindecode-1.3.0.dev177069446.dist-info/METADATA +230 -0
- braindecode-1.3.0.dev177069446.dist-info/RECORD +124 -0
- braindecode-1.3.0.dev177069446.dist-info/WHEEL +5 -0
- braindecode-1.3.0.dev177069446.dist-info/licenses/LICENSE.txt +31 -0
- braindecode-1.3.0.dev177069446.dist-info/licenses/NOTICE.txt +20 -0
- braindecode-1.3.0.dev177069446.dist-info/top_level.txt +1 -0
@@ -0,0 +1,579 @@
"""Preprocessors that work on Raw or Epochs objects."""

# Authors: Hubert Banville <hubert.jbanville@gmail.com>
#          Lukas Gemein <l.gemein@gmail.com>
#          Simon Brandt <simonbrandt@protonmail.com>
#          David Sabbagh <dav.sabbagh@gmail.com>
#          Bruno Aristimunha <b.aristimunha@gmail.com>
#
# License: BSD (3-clause)

from __future__ import annotations

import platform
import sys
from collections.abc import Iterable
from functools import cached_property, partial
from importlib import import_module
from inspect import signature
from warnings import warn

if sys.version_info < (3, 9):
    from typing import Callable
else:
    from collections.abc import Callable

import numpy as np
import pandas as pd
from joblib import Parallel, delayed
from mne import BaseEpochs, create_info
from mne.io import BaseRaw
from numpy.typing import NDArray

from braindecode.datasets.base import (
    BaseConcatDataset,
    EEGWindowsDataset,
    RawDataset,
    RecordDataset,
    WindowsDataset,
)
from braindecode.datautil.serialization import (
    _check_save_dir_empty,
    load_concat_dataset,
)
class Preprocessor(object):
    """Preprocessor for an MNE Raw or Epochs object.

    Applies the provided preprocessing function to the data of a Raw or Epochs
    object.
    If the function is provided as a string, the method with that name will be
    used (e.g., 'pick_channels', 'filter', etc.).
    If it is provided as a callable and `apply_on_array` is True, the
    `apply_function` method of the Raw or Epochs object will be used to apply
    the function to its internal arrays.
    If `apply_on_array` is False, the callable must directly modify the Raw or
    Epochs object (e.g., by calling its method(s) or modifying its attributes).

    Parameters
    ----------
    fn : str or callable
        If str, the Raw/Epochs object must have a method with that name.
        If callable, directly apply the callable to the object.
    apply_on_array : bool
        Ignored if ``fn`` is not a callable. If True, the ``apply_function`` of Raw
        and Epochs will be used to run ``fn`` on the underlying arrays directly.
        If False, ``fn`` must directly modify the Raw or Epochs object.
    **kwargs : dict
        Keyword arguments forwarded to the MNE function or callable.
    """

    def __init__(self, fn: Callable | str, *, apply_on_array: bool = True, **kwargs):
        if hasattr(fn, "__name__") and fn.__name__ == "<lambda>":
            warn("Preprocessing choices with lambda functions cannot be saved.")
        if apply_on_array and not callable(fn):
            warn(
                "apply_on_array can only be True if fn is a callable function. "
                "Automatically correcting to apply_on_array=False."
            )
            apply_on_array = False
        # We store the exact input parameters. Simpler for serialization.
        self.fn = fn
        self.apply_on_array = apply_on_array
        self.kwargs = kwargs

    @property
    def _all_attrs(self):
        return ["fn", "apply_on_array", "kwargs"]

    @property
    def _init_attrs(self):
        return [k for k in self._all_attrs if k in signature(self.__init__).parameters]

    @cached_property
    def _function(self):
        kwargs = dict(self.kwargs)
        fn = self.fn
        if self.apply_on_array:
            channel_wise = kwargs.pop("channel_wise", False)
            picks = kwargs.pop("picks", None)
            n_jobs = kwargs.pop("n_jobs", 1)
            kwargs = dict(
                fun=partial(fn, **kwargs),
                channel_wise=channel_wise,
                picks=picks,
                n_jobs=n_jobs,
            )
            fn = "apply_function"

        if callable(fn):
            return partial(fn, **kwargs)
        return partial(self._apply_str, fn=fn, **kwargs)

    @staticmethod
    def _apply_str(raw_or_epochs: BaseRaw | BaseEpochs, fn: str, **kwargs):
        if not hasattr(raw_or_epochs, fn):
            raise AttributeError(f"MNE object does not have a {fn} method.")
        return getattr(raw_or_epochs, fn)(**kwargs)

    def apply(self, raw_or_epochs: BaseRaw | BaseEpochs):
        function = self._function
        try:
            result = function(raw_or_epochs)
        except RuntimeError:
            # Maybe the function needs the data to be loaded and the data was
            # not loaded yet. Not all MNE functions need data to be loaded,
            # most importantly the 'crop' function can be lazily applied
            # without preloading data which can make the overall preprocessing
            # pipeline substantially faster.
            raw_or_epochs.load_data()
            result = function(raw_or_epochs)
        if result is not None:
            return result
        return raw_or_epochs

    def serialize(self):
        """Return a serializable representation of the Preprocessor.

        Returns
        -------
        dict
            Dictionary with keys 'fn' and 'kwargs' representing the
            Preprocessor.
        """
        out = {k: getattr(self, k) for k in self._init_attrs}
        if "fn" in out and callable(self.fn):
            out["fn"] = self.fn.__module__ + "." + self.fn.__name__
        out["__class_path__"] = (
            self.__class__.__module__ + "." + self.__class__.__name__
        )
        if "kwargs" not in out and self.kwargs:
            out["kwargs"] = self.kwargs
        return out

    @classmethod
    def deserialize(cls_parent, data: dict):
        """Create a Preprocessor from its serializable representation.

        Parameters
        ----------
        data : dict
            Dictionary with keys 'fn' and 'kwargs' representing the
            Preprocessor.

        Returns
        -------
        Preprocessor
            The deserialized Preprocessor object.
        """
        class_path = data.pop("__class_path__")
        cls_name = class_path.split(".")[-1]
        cls_module_name = ".".join(class_path.split(".")[:-1])
        cls_module = import_module(cls_module_name)
        cls = getattr(cls_module, cls_name)

        kwargs = data.pop("kwargs") if "kwargs" in data else {}

        fn = data.get("fn", None)
        if fn is not None and "." in fn:  # callable function
            fn_name = fn.split(".")[-1]
            module_name = ".".join(fn.split(".")[:-1])
            module = import_module(module_name)
            data["fn"] = getattr(module, fn_name)

        return cls(**data, **kwargs)

    def __repr__(self):
        cls_name = self.__class__.__name__
        args_str = ", ".join(
            f"{k}={getattr(self, k).__repr__()}" for k in self._init_attrs
        )
        return f"{cls_name}({args_str})"

    def _same_attr(self, other, attr):
        a = getattr(self, attr)
        b = getattr(other, attr)
        if attr == "fn" and callable(a):
            return a.__module__ == b.__module__ and a.__name__ == b.__name__
        if isinstance(a, np.ndarray):
            return np.array_equal(a, b)
        return a == b

    def __eq__(self, other):
        if not isinstance(other, Preprocessor):
            return False
        return all(self._same_attr(other, attr) for attr in self._all_attrs) and (
            self.__class__ == other.__class__
        )
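A minimal usage sketch (not part of the packaged file): the string form dispatches to the named MNE method, while the callable form is wrapped through ``apply_function``; ``serialize``/``deserialize`` round-trip a module-level callable by its dotted path. The ``scale`` helper below is hypothetical; ``resample`` is a standard MNE method.

    from braindecode.preprocessing import Preprocessor

    def scale(x, factor=1e6):
        # Hypothetical module-level callable; a lambda here would trigger the
        # "cannot be saved" warning in __init__.
        return x * factor

    p_str = Preprocessor("resample", sfreq=100)  # calls raw.resample(sfreq=100)
    p_fn = Preprocessor(scale, factor=1e6)       # runs via raw.apply_function

    state = p_fn.serialize()                     # 'fn' stored as a dotted path
    assert Preprocessor.deserialize(state) == p_fn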
def preprocess(
    concat_ds: BaseConcatDataset,
    preprocessors: list[Preprocessor],
    save_dir: str | None = None,
    overwrite: bool = False,
    n_jobs: int | None = None,
    offset: int = 0,
    copy_data: bool | None = None,
    parallel_kwargs: dict | None = None,
):
    """Apply preprocessors to a concat dataset.

    Parameters
    ----------
    concat_ds : BaseConcatDataset
        A concat of ``RecordDataset`` to be preprocessed.
    preprocessors : list of Preprocessor
        Preprocessor objects to apply to each dataset.
    save_dir : str | None
        If provided, save preprocessed data under this directory and reload
        datasets in ``concat_ds`` with ``preload=False``.
    overwrite : bool
        When ``save_dir`` is provided, controls whether to delete the old
        subdirectories that will be written to under ``save_dir``. If False and
        the corresponding subdirectories already exist, a ``FileExistsError``
        is raised.
    n_jobs : int | None
        Number of jobs for parallel execution. See ``joblib.Parallel`` for details.
    offset : int
        Integer added to the dataset id in the concat. Useful when processing
        and saving very large datasets in chunks to preserve original positions.
    copy_data : bool | None
        Whether the data passed to parallel jobs should be copied or passed by
        reference.
    parallel_kwargs : dict | None
        Additional keyword arguments forwarded to ``joblib.Parallel``.
        Defaults to None (equivalent to ``{}``).
        See https://joblib.readthedocs.io/en/stable/generated/joblib.Parallel.html for details.

    Returns
    -------
    BaseConcatDataset
        Preprocessed dataset.
    """
    # In case of serialization, make sure directory is available before
    # preprocessing
    if save_dir is not None and not overwrite:
        _check_save_dir_empty(save_dir)

    if not isinstance(preprocessors, Iterable):
        raise ValueError("preprocessors must be a list of Preprocessor objects.")
    for elem in preprocessors:
        assert hasattr(elem, "apply"), "Preprocessor object needs an `apply` method."

    parallel_processing = (n_jobs is not None) and (n_jobs != 1)

    parallel_params = {} if parallel_kwargs is None else dict(parallel_kwargs)
    parallel_params.setdefault(
        "prefer", "threads" if platform.system() == "Windows" else None
    )

    list_of_ds = Parallel(n_jobs=n_jobs, **parallel_params)(
        delayed(_preprocess)(
            ds,
            i + offset,
            preprocessors,
            save_dir,
            overwrite,
            copy_data=(
                (parallel_processing and (save_dir is None))
                if copy_data is None
                else copy_data
            ),
        )
        for i, ds in enumerate(concat_ds.datasets)
    )

    if save_dir is not None:  # Reload datasets and replace in concat_ds
        ids_to_load = [i + offset for i in range(len(concat_ds.datasets))]
        concat_ds_reloaded = load_concat_dataset(
            save_dir,
            preload=False,
            target_name=None,
            ids_to_load=ids_to_load,
        )
        _replace_inplace(concat_ds, concat_ds_reloaded)
    else:
        if parallel_processing:  # joblib made copies
            _replace_inplace(concat_ds, BaseConcatDataset(list_of_ds))
        else:  # joblib did not make copies, the
            # preprocessing happened in-place
            # Recompute cumulative sizes as transforms might have changed them
            concat_ds.cumulative_sizes = concat_ds.cumsum(concat_ds.datasets)

    return concat_ds
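For orientation, a hedged end-to-end sketch of driving ``preprocess`` (the MOABB dataset name is illustrative and depends on the installed MOABB version):

    from braindecode.datasets import MOABBDataset
    from braindecode.preprocessing import Preprocessor, preprocess

    concat_ds = MOABBDataset(dataset_name="BNCI2014_001", subject_ids=[1])
    preprocessors = [
        Preprocessor("pick_types", eeg=True, stim=False),  # keep EEG channels
        Preprocessor("filter", l_freq=4.0, h_freq=38.0),   # bandpass via raw.filter
    ]
    # With n_jobs=None or 1 this runs in-place; with save_dir set, each
    # recording is written to disk and the concat is reloaded with preload=False.
    concat_ds = preprocess(concat_ds, preprocessors, n_jobs=1)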
def _replace_inplace(concat_ds, new_concat_ds):
    """Replace subdatasets and preproc_kwargs of a BaseConcatDataset inplace.

    Parameters
    ----------
    concat_ds : BaseConcatDataset
        Dataset to modify inplace.
    new_concat_ds : BaseConcatDataset
        Dataset to use to modify ``concat_ds``.
    """
    if len(concat_ds.datasets) != len(new_concat_ds.datasets):
        raise ValueError("Both inputs must have the same length.")
    for i in range(len(new_concat_ds.datasets)):
        concat_ds.datasets[i] = new_concat_ds.datasets[i]

    concat_kind = "raw" if hasattr(concat_ds.datasets[0], "raw") else "window"
    preproc_kwargs_attr = concat_kind + "_preproc_kwargs"
    if hasattr(new_concat_ds, preproc_kwargs_attr):
        setattr(
            concat_ds, preproc_kwargs_attr, getattr(new_concat_ds, preproc_kwargs_attr)
        )

    # Recompute cumulative_sizes after replacing datasets
    concat_ds.cumulative_sizes = concat_ds.cumsum(concat_ds.datasets)


def _preprocess(
    ds: RecordDataset,
    ds_index,
    preprocessors,
    save_dir=None,
    overwrite=False,
    copy_data=False,
):
    """Apply preprocessor(s) to a Raw or Epochs object.

    Parameters
    ----------
    ds : RecordDataset
        Dataset object to preprocess.
    ds_index : int
        Index of the ``RecordDataset`` in its ``BaseConcatDataset``. Ignored if
        save_dir is None.
    preprocessors : list(Preprocessor)
        List of preprocessors to apply to the dataset.
    save_dir : str | None
        If provided, save the preprocessed RecordDataset in the
        specified directory.
    overwrite : bool
        If True, overwrite existing file with the same name.
    copy_data : bool
        First copy the data in case it is preloaded. Necessary for parallel
        processing to work.
    """

    def _preprocess_raw_or_epochs(raw_or_epochs, preprocessors):
        # Copying the data is necessary in some scenarios for parallel
        # processing to work when the data is in memory (otherwise an error
        # about _data not being writeable is raised)
        if raw_or_epochs.preload and copy_data:
            raw_or_epochs._data = raw_or_epochs._data.copy()
        for preproc in preprocessors:
            raw_or_epochs = preproc.apply(raw_or_epochs)
        return raw_or_epochs

    if hasattr(ds, "raw"):
        if isinstance(ds, EEGWindowsDataset):
            warn(
                f"Applying preprocessors {preprocessors} to the mne.io.Raw of an EEGWindowsDataset."
            )
        processed = _preprocess_raw_or_epochs(ds.raw, preprocessors)
        if processed is not ds.raw:
            ds.raw = processed
    elif hasattr(ds, "windows"):
        processed = _preprocess_raw_or_epochs(ds.windows, preprocessors)
        if processed is not ds.windows:
            ds.windows = processed
    else:
        raise ValueError(
            "Can only preprocess concatenation of RecordDataset, "
            "with either a `raw` or `windows` attribute."
        )

    # Store preprocessing keyword arguments in the dataset
    _set_preproc_kwargs(ds, preprocessors)

    if save_dir is not None:
        concat_ds = BaseConcatDataset([ds])
        concat_ds.save(save_dir, overwrite=overwrite, offset=ds_index)
    else:
        return ds


def _set_preproc_kwargs(ds, preprocessors):
    """Record preprocessing keyword arguments in a RecordDataset.

    Parameters
    ----------
    ds : RecordDataset
        Dataset in which to record preprocessing keyword arguments.
    preprocessors : list
        List of preprocessors.
    """
    preproc_kwargs = [p.serialize() for p in preprocessors]
    if isinstance(ds, WindowsDataset):
        kind = "window"
    elif isinstance(ds, EEGWindowsDataset):
        kind = "raw"
    elif isinstance(ds, RawDataset):
        kind = "raw"
    else:
        raise TypeError(f"ds must be a RecordDataset, got {type(ds)}")
    old_preproc_kwargs = getattr(ds, kind + "_preproc_kwargs")
    old_preproc_kwargs.extend(preproc_kwargs)
def exponential_moving_standardize(
    data: NDArray,
    factor_new: float = 0.001,
    init_block_size: int | None = None,
    eps: float = 1e-4,
):
    r"""Perform exponential moving standardization.

    Compute the exponential moving mean :math:`m_t` at time `t` as
    :math:`m_t=\mathrm{factornew} \cdot mean(x_t) + (1 - \mathrm{factornew}) \cdot m_{t-1}`.

    Then, compute exponential moving variance :math:`v_t` at time `t` as
    :math:`v_t=\mathrm{factornew} \cdot (m_t - x_t)^2 + (1 - \mathrm{factornew}) \cdot v_{t-1}`.

    Finally, standardize the data point :math:`x_t` at time `t` as:
    :math:`x'_t=(x_t - m_t) / \max(\sqrt{v_t}, eps)`.

    Parameters
    ----------
    data : np.ndarray (n_channels, n_times)
    factor_new : float
    init_block_size : int
        Standardize data up to this index with regular standardization.
    eps : float
        Stabilizer for division by zero variance.

    Returns
    -------
    standardized : np.ndarray (n_channels, n_times)
        Standardized data.
    """
    data = data.T
    df = pd.DataFrame(data)
    meaned = df.ewm(alpha=factor_new).mean()
    demeaned = df - meaned
    squared = demeaned * demeaned
    square_ewmed = squared.ewm(alpha=factor_new).mean()
    standardized = demeaned / np.maximum(eps, np.sqrt(np.array(square_ewmed)))
    standardized = np.array(standardized)
    if init_block_size is not None:
        i_time_axis = 0
        init_mean = np.mean(data[0:init_block_size], axis=i_time_axis, keepdims=True)
        init_std = np.std(data[0:init_block_size], axis=i_time_axis, keepdims=True)
        init_block_standardized = (data[0:init_block_size] - init_mean) / np.maximum(
            eps, init_std
        )
        standardized[0:init_block_size] = init_block_standardized
    return standardized.T
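A small sanity-check sketch for the function above (shape-preserving; exact values depend on pandas' EWM initialization):

    import numpy as np
    from braindecode.preprocessing import exponential_moving_standardize

    rng = np.random.default_rng(0)
    data = rng.normal(loc=5.0, scale=2.0, size=(3, 1000))  # (n_channels, n_times)
    out = exponential_moving_standardize(data, factor_new=1e-3, init_block_size=200)
    assert out.shape == data.shape
    # The first 200 samples use regular standardization; later samples become
    # approximately zero-mean, unit-variance once the moving statistics converge.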
def exponential_moving_demean(
    data: NDArray, factor_new: float = 0.001, init_block_size: int | None = None
):
    r"""Perform exponential moving demeaning.

    Compute the exponential moving mean :math:`m_t` at time `t` as
    :math:`m_t=\mathrm{factornew} \cdot mean(x_t) + (1 - \mathrm{factornew}) \cdot m_{t-1}`.

    Demean the data point :math:`x_t` at time `t` as:
    :math:`x'_t=(x_t - m_t)`.

    Parameters
    ----------
    data : np.ndarray (n_channels, n_times)
    factor_new : float
    init_block_size : int
        Demean data up to this index with regular demeaning.

    Returns
    -------
    demeaned : np.ndarray (n_channels, n_times)
        Demeaned data.
    """
    data = data.T
    df = pd.DataFrame(data)
    meaned = df.ewm(alpha=factor_new).mean()
    demeaned = df - meaned
    demeaned = np.array(demeaned)
    if init_block_size is not None:
        i_time_axis = 0
        init_mean = np.mean(data[0:init_block_size], axis=i_time_axis, keepdims=True)
        demeaned[0:init_block_size] = data[0:init_block_size] - init_mean
    return demeaned.T


def filterbank(
    raw: BaseRaw,
    frequency_bands: list[tuple[float, float]],
    drop_original_signals: bool = True,
    order_by_frequency_band: bool = False,
    **mne_filter_kwargs,
):
    """Apply multiple bandpass filters to the signals in raw.

    The raw is modified in-place, and the number of channels in raw is updated
    to len(frequency_bands) * len(raw.ch_names) (minus len(raw.ch_names) if
    drop_original_signals is True).

    Parameters
    ----------
    raw : mne.io.Raw
        The raw signals to be filtered.
    frequency_bands : list(tuple)
        The frequency bands to be filtered for (e.g. [(4, 8), (8, 13)]).
    drop_original_signals : bool
        Whether to drop the original unfiltered signals.
    order_by_frequency_band : bool
        If True, the returned channels are ordered by frequency band, so for
        channels Cz, O1 and filterbank ranges [(4, 8), (8, 13)], the returned
        channels are [Cz_4-8, O1_4-8, Cz_8-13, O1_8-13]. If False, the order
        is [Cz_4-8, Cz_8-13, O1_4-8, O1_8-13].
    mne_filter_kwargs : dict
        Keyword arguments for filtering supported by mne.io.Raw.filter().
        Please refer to mne for a detailed explanation.
    """
    if not frequency_bands:
        raise ValueError(f"Expected at least one frequency band, got {frequency_bands}")
    if not all([len(ch_name) < 8 for ch_name in raw.ch_names]):
        warn(
            "Try to use shorter channel names, since frequency band "
            "annotation requires an estimated 4-8 chars depending on the "
            "frequency ranges. Will truncate to 15 chars (mne max)."
        )
    original_ch_names = raw.ch_names
    all_filtered = []
    for l_freq, h_freq in frequency_bands:
        filtered = raw.copy()
        filtered.filter(l_freq=l_freq, h_freq=h_freq, **mne_filter_kwargs)
        # mne automatically changes the highpass/lowpass info values
        # when applying filters and channels can't be added if they have
        # different such parameters. Not needed when making picks as
        # high pass is not modified by filter if pick is specified

        ch_names = filtered.info.ch_names
        ch_types = filtered.info.get_channel_types()
        sampling_freq = filtered.info["sfreq"]

        info = create_info(ch_names=ch_names, ch_types=ch_types, sfreq=sampling_freq)

        filtered.info = info

        # add frequency band annotation to channel names
        # truncate to a max of 15 characters, since mne does not allow for more
        filtered.rename_channels(
            {
                old_name: (old_name + f"_{l_freq}-{h_freq}")[-15:]
                for old_name in filtered.ch_names
            }
        )
        all_filtered.append(filtered)
    raw.add_channels(all_filtered)
    if not order_by_frequency_band:
        # order channels by name and not by frequency band:
        # index the list with a stepsize of the number of channels for each of
        # the original channels
        chs_by_freq_band = []
        for i in range(len(original_ch_names)):
            chs_by_freq_band.extend(raw.ch_names[i :: len(original_ch_names)])
        raw.reorder_channels(chs_by_freq_band)
    if drop_original_signals:
        raw.drop_channels(original_ch_names)
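Finally, a hedged sketch of ``filterbank`` on a synthetic Raw, illustrating the channel renaming and ordering described in the docstring:

    import numpy as np
    import mne
    from braindecode.preprocessing import filterbank

    info = mne.create_info(ch_names=["Cz", "O1"], sfreq=250.0, ch_types="eeg")
    raw = mne.io.RawArray(np.random.randn(2, 2500), info)  # 10 s of noise
    filterbank(raw, frequency_bands=[(4.0, 8.0), (8.0, 13.0)])
    # With order_by_frequency_band=False (the default), channels are grouped
    # by original channel name:
    print(raw.ch_names)  # ['Cz_4.0-8.0', 'Cz_8.0-13.0', 'O1_4.0-8.0', 'O1_8.0-13.0']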