braindecode 0.8__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- braindecode/__init__.py +1 -2
- braindecode/augmentation/__init__.py +50 -0
- braindecode/augmentation/base.py +222 -0
- braindecode/augmentation/functional.py +1096 -0
- braindecode/augmentation/transforms.py +1274 -0
- braindecode/classifier.py +26 -24
- braindecode/datasets/__init__.py +34 -0
- braindecode/datasets/base.py +840 -0
- braindecode/datasets/bbci.py +694 -0
- braindecode/datasets/bcicomp.py +194 -0
- braindecode/datasets/bids.py +245 -0
- braindecode/datasets/mne.py +172 -0
- braindecode/datasets/moabb.py +209 -0
- braindecode/datasets/nmt.py +311 -0
- braindecode/datasets/sleep_physio_challe_18.py +412 -0
- braindecode/datasets/sleep_physionet.py +125 -0
- braindecode/datasets/tuh.py +588 -0
- braindecode/datasets/xy.py +95 -0
- braindecode/datautil/__init__.py +49 -0
- braindecode/datautil/serialization.py +342 -0
- braindecode/datautil/util.py +41 -0
- braindecode/eegneuralnet.py +63 -47
- braindecode/functional/__init__.py +10 -0
- braindecode/functional/functions.py +251 -0
- braindecode/functional/initialization.py +47 -0
- braindecode/models/__init__.py +52 -0
- braindecode/models/atcnet.py +652 -0
- braindecode/models/attentionbasenet.py +550 -0
- braindecode/models/base.py +296 -0
- braindecode/models/biot.py +483 -0
- braindecode/models/contrawr.py +296 -0
- braindecode/models/ctnet.py +450 -0
- braindecode/models/deep4.py +322 -0
- braindecode/models/deepsleepnet.py +295 -0
- braindecode/models/eegconformer.py +372 -0
- braindecode/models/eeginception_erp.py +304 -0
- braindecode/models/eeginception_mi.py +371 -0
- braindecode/models/eegitnet.py +301 -0
- braindecode/models/eegminer.py +255 -0
- braindecode/models/eegnet.py +473 -0
- braindecode/models/eegnex.py +247 -0
- braindecode/models/eegresnet.py +362 -0
- braindecode/models/eegsimpleconv.py +199 -0
- braindecode/models/eegtcnet.py +335 -0
- braindecode/models/fbcnet.py +221 -0
- braindecode/models/fblightconvnet.py +313 -0
- braindecode/models/fbmsnet.py +325 -0
- braindecode/models/hybrid.py +126 -0
- braindecode/models/ifnet.py +441 -0
- braindecode/models/labram.py +1166 -0
- braindecode/models/msvtnet.py +375 -0
- braindecode/models/sccnet.py +182 -0
- braindecode/models/shallow_fbcsp.py +208 -0
- braindecode/models/signal_jepa.py +1012 -0
- braindecode/models/sinc_shallow.py +337 -0
- braindecode/models/sleep_stager_blanco_2020.py +167 -0
- braindecode/models/sleep_stager_chambon_2018.py +157 -0
- braindecode/models/sleep_stager_eldele_2021.py +536 -0
- braindecode/models/sparcnet.py +378 -0
- braindecode/models/summary.csv +41 -0
- braindecode/models/syncnet.py +232 -0
- braindecode/models/tcn.py +273 -0
- braindecode/models/tidnet.py +395 -0
- braindecode/models/tsinception.py +258 -0
- braindecode/models/usleep.py +340 -0
- braindecode/models/util.py +133 -0
- braindecode/modules/__init__.py +38 -0
- braindecode/modules/activation.py +60 -0
- braindecode/modules/attention.py +757 -0
- braindecode/modules/blocks.py +108 -0
- braindecode/modules/convolution.py +274 -0
- braindecode/modules/filter.py +632 -0
- braindecode/modules/layers.py +133 -0
- braindecode/modules/linear.py +50 -0
- braindecode/modules/parametrization.py +38 -0
- braindecode/modules/stats.py +77 -0
- braindecode/modules/util.py +77 -0
- braindecode/modules/wrapper.py +75 -0
- braindecode/preprocessing/__init__.py +37 -0
- braindecode/preprocessing/mne_preprocess.py +77 -0
- braindecode/preprocessing/preprocess.py +478 -0
- braindecode/preprocessing/windowers.py +1031 -0
- braindecode/regressor.py +23 -12
- braindecode/samplers/__init__.py +18 -0
- braindecode/samplers/base.py +401 -0
- braindecode/samplers/ssl.py +263 -0
- braindecode/training/__init__.py +23 -0
- braindecode/training/callbacks.py +23 -0
- braindecode/training/losses.py +105 -0
- braindecode/training/scoring.py +483 -0
- braindecode/util.py +55 -59
- braindecode/version.py +1 -1
- braindecode/visualization/__init__.py +8 -0
- braindecode/visualization/confusion_matrices.py +289 -0
- braindecode/visualization/gradients.py +57 -0
- {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info}/METADATA +39 -55
- braindecode-1.0.0.dist-info/RECORD +101 -0
- {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info}/WHEEL +1 -1
- {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info/licenses}/LICENSE.txt +1 -1
- braindecode-1.0.0.dist-info/licenses/NOTICE.txt +20 -0
- braindecode-0.8.dist-info/RECORD +0 -11
- {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info}/top_level.txt +0 -0
braindecode/preprocessing/preprocess.py
@@ -0,0 +1,478 @@
"""Preprocessors that work on Raw or Epochs objects."""

# Authors: Hubert Banville <hubert.jbanville@gmail.com>
#          Lukas Gemein <l.gemein@gmail.com>
#          Simon Brandt <simonbrandt@protonmail.com>
#          David Sabbagh <dav.sabbagh@gmail.com>
#          Bruno Aristimunha <b.aristimunha@gmail.com>
#
# License: BSD (3-clause)
from __future__ import annotations

import platform
import sys
from collections.abc import Iterable
from functools import partial
from warnings import warn

if sys.version_info < (3, 9):
    from typing import Callable
else:
    from collections.abc import Callable

import numpy as np
import pandas as pd
from joblib import Parallel, delayed
from mne import BaseEpochs, create_info
from mne.io import BaseRaw
from numpy.typing import NDArray

from braindecode.datasets.base import (
    BaseConcatDataset,
    BaseDataset,
    EEGWindowsDataset,
    WindowsDataset,
)
from braindecode.datautil.serialization import (
    _check_save_dir_empty,
    load_concat_dataset,
)

class Preprocessor(object):
    """Preprocessor for an MNE Raw or Epochs object.

    Applies the provided preprocessing function to the data of a Raw or Epochs
    object.
    If the function is provided as a string, the method with that name will be
    used (e.g., 'pick_channels', 'filter', etc.).
    If it is provided as a callable and `apply_on_array` is True, the
    `apply_function` method of Raw and Epochs objects will be used to apply the
    function on the internal arrays of Raw and Epochs.
    If `apply_on_array` is False, the callable must directly modify the Raw or
    Epochs object (e.g., by calling its method(s) or modifying its attributes).

    Parameters
    ----------
    fn: str or callable
        If str, the Raw/Epochs object must have a method with that name.
        If callable, directly apply the callable to the object.
    apply_on_array : bool
        Ignored if `fn` is not a callable. If True, the `apply_function` of Raw
        and Epochs objects will be used to run `fn` on the underlying arrays
        directly. If False, `fn` must directly modify the Raw or Epochs object.
    kwargs:
        Keyword arguments to be forwarded to the MNE function.
    """

    def __init__(self, fn: Callable | str, *, apply_on_array: bool = True, **kwargs):
        if hasattr(fn, "__name__") and fn.__name__ == "<lambda>":
            warn("Preprocessing choices with lambda functions cannot be saved.")
        if callable(fn) and apply_on_array:
            channel_wise = kwargs.pop("channel_wise", False)
            picks = kwargs.pop("picks", None)
            n_jobs = kwargs.pop("n_jobs", 1)
            kwargs = dict(
                fun=partial(fn, **kwargs),
                channel_wise=channel_wise,
                picks=picks,
                n_jobs=n_jobs,
            )
            fn = "apply_function"
        self.fn = fn
        self.kwargs = kwargs

    def apply(self, raw_or_epochs: BaseRaw | BaseEpochs):
        try:
            self._try_apply(raw_or_epochs)
        except RuntimeError:
            # Maybe the function needs the data to be loaded and the data was
            # not loaded yet. Not all MNE functions need data to be loaded,
            # most importantly the 'crop' function can be lazily applied
            # without preloading data which can make the overall preprocessing
            # pipeline substantially faster.
            raw_or_epochs.load_data()
            self._try_apply(raw_or_epochs)

    def _try_apply(self, raw_or_epochs):
        if callable(self.fn):
            self.fn(raw_or_epochs, **self.kwargs)
        else:
            if not hasattr(raw_or_epochs, self.fn):
                raise AttributeError(f"MNE object does not have a {self.fn} method.")
            getattr(raw_or_epochs, self.fn)(**self.kwargs)
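For reference, a minimal sketch of the two `fn` styles this class accepts (the filter settings and scaling function are illustrative, not taken from the diff):

    from braindecode.preprocessing import Preprocessor

    # String fn: resolved to the MNE method of the same name, e.g. Raw.filter().
    bandpass = Preprocessor("filter", l_freq=4.0, h_freq=38.0)

    def to_microvolts(x):
        # Receives the underlying NumPy array because apply_on_array=True
        # (the default), so it runs through Raw.apply_function().
        return x * 1e6

    scale = Preprocessor(to_microvolts)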
def preprocess(
    concat_ds: BaseConcatDataset,
    preprocessors: list[Preprocessor],
    save_dir: str | None = None,
    overwrite: bool = False,
    n_jobs: int | None = None,
    offset: int = 0,
    copy_data: bool | None = None,
):
    """Apply preprocessors to a concat dataset.

    Parameters
    ----------
    concat_ds: BaseConcatDataset
        A concat of BaseDataset or WindowsDataset datasets to be preprocessed.
    preprocessors: list(Preprocessor)
        List of Preprocessor objects to apply to the dataset.
    save_dir : str | None
        If a string, the preprocessed data will be saved under the specified
        directory and the datasets in ``concat_ds`` will be reloaded with
        `preload=False`.
    overwrite : bool
        When `save_dir` is provided, controls whether to delete the old
        subdirectories that will be written to under `save_dir`. If False and
        the corresponding subdirectories already exist, a ``FileExistsError``
        will be raised.
    n_jobs : int | None
        Number of jobs for parallel execution. See `joblib.Parallel` for
        a more detailed explanation.
    offset : int
        If provided, the integer is added to the id of the dataset in the
        concat. This is useful in the setting of very large datasets, where
        one dataset has to be processed and saved at a time to account for
        its original position.
    copy_data : bool | None
        Whether the data passed to the different jobs should be copied or
        passed by reference.

    Returns
    -------
    BaseConcatDataset:
        Preprocessed dataset.
    """
    # In case of serialization, make sure directory is available before
    # preprocessing
    if save_dir is not None and not overwrite:
        _check_save_dir_empty(save_dir)

    if not isinstance(preprocessors, Iterable):
        raise ValueError("preprocessors must be a list of Preprocessor objects.")
    for elem in preprocessors:
        assert hasattr(elem, "apply"), "Preprocessor object needs an `apply` method."

    parallel_processing = (n_jobs is not None) and (n_jobs != 1)

    job_prefer = "threads" if platform.system() == "Windows" else None
    list_of_ds = Parallel(n_jobs=n_jobs, prefer=job_prefer)(
        delayed(_preprocess)(
            ds,
            i + offset,
            preprocessors,
            save_dir,
            overwrite,
            copy_data=(
                (parallel_processing and (save_dir is None))
                if copy_data is None
                else copy_data
            ),
        )
        for i, ds in enumerate(concat_ds.datasets)
    )

    if save_dir is not None:  # Reload datasets and replace in concat_ds
        ids_to_load = [i + offset for i in range(len(concat_ds.datasets))]
        concat_ds_reloaded = load_concat_dataset(
            save_dir,
            preload=False,
            target_name=None,
            ids_to_load=ids_to_load,
        )
        _replace_inplace(concat_ds, concat_ds_reloaded)
    else:
        if parallel_processing:  # joblib made copies
            _replace_inplace(concat_ds, BaseConcatDataset(list_of_ds))
        else:  # joblib did not make copies, the preprocessing happened in-place
            # Recompute cumulative sizes as transforms might have changed them
            concat_ds.cumulative_sizes = concat_ds.cumsum(concat_ds.datasets)

    return concat_ds
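A hedged end-to-end sketch of how `preprocess` is typically called (the dataset name, picks, bands, and `n_jobs` are illustrative only):

    from braindecode.datasets import MOABBDataset
    from braindecode.preprocessing import Preprocessor, preprocess

    concat_ds = MOABBDataset(dataset_name="BNCI2014_001", subject_ids=[1])
    preprocessors = [
        Preprocessor("pick", picks="eeg"),
        Preprocessor("filter", l_freq=4.0, h_freq=38.0),
    ]
    # Runs over the recordings in parallel; passing save_dir instead would
    # serialize each preprocessed recording and reload it with preload=False.
    preprocess(concat_ds, preprocessors, n_jobs=2)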
def _replace_inplace(concat_ds, new_concat_ds):
    """Replace subdatasets and preproc_kwargs of a BaseConcatDataset inplace.

    Parameters
    ----------
    concat_ds : BaseConcatDataset
        Dataset to modify inplace.
    new_concat_ds : BaseConcatDataset
        Dataset to use to modify ``concat_ds``.
    """
    if len(concat_ds.datasets) != len(new_concat_ds.datasets):
        raise ValueError("Both inputs must have the same length.")
    for i in range(len(new_concat_ds.datasets)):
        concat_ds.datasets[i] = new_concat_ds.datasets[i]

    concat_kind = "raw" if hasattr(concat_ds.datasets[0], "raw") else "window"
    preproc_kwargs_attr = concat_kind + "_preproc_kwargs"
    if hasattr(new_concat_ds, preproc_kwargs_attr):
        setattr(
            concat_ds, preproc_kwargs_attr, getattr(new_concat_ds, preproc_kwargs_attr)
        )
def _preprocess(
    ds, ds_index, preprocessors, save_dir=None, overwrite=False, copy_data=False
):
    """Apply preprocessor(s) to Raw or Epochs object.

    Parameters
    ----------
    ds: BaseDataset | WindowsDataset
        Dataset object to preprocess.
    ds_index : int
        Index of the BaseDataset in its BaseConcatDataset. Ignored if save_dir
        is None.
    preprocessors: list(Preprocessor)
        List of preprocessors to apply to the dataset.
    save_dir : str | None
        If provided, save the preprocessed BaseDataset in the
        specified directory.
    overwrite : bool
        If True, overwrite existing file with the same name.
    copy_data : bool
        First copy the data in case it is preloaded. Necessary for parallel
        processing to work.
    """

    def _preprocess_raw_or_epochs(raw_or_epochs, preprocessors):
        # Copying the data is necessary in some scenarios for parallel
        # processing to work when data is in memory (else error about _data
        # not being writeable)
        if raw_or_epochs.preload and copy_data:
            raw_or_epochs._data = raw_or_epochs._data.copy()
        for preproc in preprocessors:
            preproc.apply(raw_or_epochs)

    if hasattr(ds, "raw"):
        if isinstance(ds, EEGWindowsDataset):
            warn(
                f"Applying preprocessors {preprocessors} to the mne.io.Raw of an EEGWindowsDataset."
            )
        _preprocess_raw_or_epochs(ds.raw, preprocessors)
    elif hasattr(ds, "windows"):
        _preprocess_raw_or_epochs(ds.windows, preprocessors)
    else:
        raise ValueError(
            "Can only preprocess concatenation of BaseDataset or "
            "WindowsDataset, with either a `raw` or `windows` attribute."
        )

    # Store preprocessing keyword arguments in the dataset
    _set_preproc_kwargs(ds, preprocessors)

    if save_dir is not None:
        concat_ds = BaseConcatDataset([ds])
        concat_ds.save(save_dir, overwrite=overwrite, offset=ds_index)
    else:
        return ds
def _get_preproc_kwargs(preprocessors):
    preproc_kwargs = []
    for p in preprocessors:
        # in case of an MNE function, fn is a str and kwargs is a dict
        func_name = p.fn
        func_kwargs = p.kwargs
        # in case of another function:
        # if apply_on_array=False
        if callable(p.fn):
            func_name = p.fn.__name__
        # if apply_on_array=True
        else:
            if "fun" in p.fn:
                func_name = p.kwargs["fun"].func.__name__
                func_kwargs = p.kwargs["fun"].keywords
        preproc_kwargs.append((func_name, func_kwargs))
    return preproc_kwargs
def _set_preproc_kwargs(ds, preprocessors):
    """Record preprocessing keyword arguments in BaseDataset or WindowsDataset.

    Parameters
    ----------
    ds : BaseDataset | WindowsDataset
        Dataset in which to record preprocessing keyword arguments.
    preprocessors : list
        List of preprocessors.
    """
    preproc_kwargs = _get_preproc_kwargs(preprocessors)
    if isinstance(ds, WindowsDataset):
        kind = "window"
        if isinstance(ds, EEGWindowsDataset):
            kind = "raw"
    elif isinstance(ds, BaseDataset):
        kind = "raw"
    else:
        raise TypeError(f"ds must be a BaseDataset or a WindowsDataset, got {type(ds)}")
    setattr(ds, kind + "_preproc_kwargs", preproc_kwargs)
def exponential_moving_standardize(
    data: NDArray,
    factor_new: float = 0.001,
    init_block_size: int | None = None,
    eps: float = 1e-4,
):
    r"""Perform exponential moving standardization.

    Compute the exponential moving mean :math:`m_t` at time `t` as
    :math:`m_t=\mathrm{factornew} \cdot x_t + (1 - \mathrm{factornew}) \cdot m_{t-1}`.

    Then, compute the exponential moving variance :math:`v_t` at time `t` as
    :math:`v_t=\mathrm{factornew} \cdot (m_t - x_t)^2 + (1 - \mathrm{factornew}) \cdot v_{t-1}`.

    Finally, standardize the data point :math:`x_t` at time `t` as:
    :math:`x'_t=(x_t - m_t) / \max(\sqrt{v_t}, eps)`.

    Parameters
    ----------
    data: np.ndarray (n_channels, n_times)
    factor_new: float
    init_block_size: int
        Standardize data before this index with regular standardization.
    eps: float
        Stabilizer for division by zero variance.

    Returns
    -------
    standardized: np.ndarray (n_channels, n_times)
        Standardized data.
    """
    data = data.T
    df = pd.DataFrame(data)
    meaned = df.ewm(alpha=factor_new).mean()
    demeaned = df - meaned
    squared = demeaned * demeaned
    square_ewmed = squared.ewm(alpha=factor_new).mean()
    standardized = demeaned / np.maximum(eps, np.sqrt(np.array(square_ewmed)))
    standardized = np.array(standardized)
    if init_block_size is not None:
        i_time_axis = 0
        init_mean = np.mean(data[0:init_block_size], axis=i_time_axis, keepdims=True)
        init_std = np.std(data[0:init_block_size], axis=i_time_axis, keepdims=True)
        init_block_standardized = (data[0:init_block_size] - init_mean) / np.maximum(
            eps, init_std
        )
        standardized[0:init_block_size] = init_block_standardized
    return standardized.T
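The recursion in the docstring corresponds to what pandas computes with `ewm(alpha=factor_new).mean()` (up to pandas' default weight adjustment for early samples). A small self-contained sanity check, with made-up numbers:

    import numpy as np
    from braindecode.preprocessing import exponential_moving_standardize

    rng = np.random.default_rng(0)
    data = rng.normal(loc=5.0, scale=2.0, size=(3, 1000))  # (n_channels, n_times)

    out = exponential_moving_standardize(data, factor_new=0.001, init_block_size=200)
    assert out.shape == (3, 1000)
    # Away from the start, the running mean/variance track the signal, so the
    # output is approximately zero-mean with unit variance.
    print(out[:, 500:].mean(), out[:, 500:].std())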
def exponential_moving_demean(
    data: NDArray, factor_new: float = 0.001, init_block_size: int | None = None
):
    r"""Perform exponential moving demeaning.

    Compute the exponential moving mean :math:`m_t` at time `t` as
    :math:`m_t=\mathrm{factornew} \cdot x_t + (1 - \mathrm{factornew}) \cdot m_{t-1}`.

    Demean the data point :math:`x_t` at time `t` as:
    :math:`x'_t=(x_t - m_t)`.

    Parameters
    ----------
    data: np.ndarray (n_channels, n_times)
    factor_new: float
    init_block_size: int
        Demean data before this index with regular demeaning.

    Returns
    -------
    demeaned: np.ndarray (n_channels, n_times)
        Demeaned data.
    """
    data = data.T
    df = pd.DataFrame(data)
    meaned = df.ewm(alpha=factor_new).mean()
    demeaned = df - meaned
    demeaned = np.array(demeaned)
    if init_block_size is not None:
        i_time_axis = 0
        init_mean = np.mean(data[0:init_block_size], axis=i_time_axis, keepdims=True)
        demeaned[0:init_block_size] = data[0:init_block_size] - init_mean
    return demeaned.T
def filterbank(
    raw: BaseRaw,
    frequency_bands: list[tuple[float, float]],
    drop_original_signals: bool = True,
    order_by_frequency_band: bool = False,
    **mne_filter_kwargs,
):
    """Applies multiple bandpass filters to the signals in raw. The raw will be
    modified in-place and number of channels in raw will be updated to
    len(frequency_bands) * len(raw.ch_names) (-len(raw.ch_names) if
    drop_original_signals).

    Parameters
    ----------
    raw: mne.io.Raw
        The raw signals to be filtered.
    frequency_bands: list(tuple)
        The frequency bands to be filtered for (e.g. [(4, 8), (8, 13)]).
    drop_original_signals: bool
        Whether to drop the original unfiltered signals.
    order_by_frequency_band: bool
        If True will return channels ordered by frequency bands, so if there
        are channels Cz, O1 and filterbank ranges [(4,8), (8,13)], returned
        channels will be [Cz_4-8, O1_4-8, Cz_8-13, O1_8-13]. If False, order
        will be [Cz_4-8, Cz_8-13, O1_4-8, O1_8-13].
    mne_filter_kwargs: dict
        Keyword arguments for filtering supported by mne.io.Raw.filter().
        Please refer to mne for a detailed explanation.
    """
    if not frequency_bands:
        raise ValueError(f"Expected at least one frequency band, got {frequency_bands}")
    if not all([len(ch_name) < 8 for ch_name in raw.ch_names]):
        warn(
            "Try to use shorter channel names, since frequency band "
            "annotation requires an estimated 4-8 chars depending on the "
            "frequency ranges. Will truncate to 15 chars (mne max)."
        )
    original_ch_names = raw.ch_names
    all_filtered = []
    for l_freq, h_freq in frequency_bands:
        filtered = raw.copy()
        filtered.filter(l_freq=l_freq, h_freq=h_freq, **mne_filter_kwargs)
        # mne automatically changes the highpass/lowpass info values
        # when applying filters and channels can't be added if they have
        # different such parameters. Not needed when making picks as
        # high pass is not modified by filter if pick is specified

        ch_names = filtered.info.ch_names
        ch_types = filtered.info.get_channel_types()
        sampling_freq = filtered.info["sfreq"]

        info = create_info(ch_names=ch_names, ch_types=ch_types, sfreq=sampling_freq)

        filtered.info = info

        # add frequency band annotation to channel names
        # truncate to a max of 15 characters, since mne does not allow for more
        filtered.rename_channels(
            {
                old_name: (old_name + f"_{l_freq}-{h_freq}")[-15:]
                for old_name in filtered.ch_names
            }
        )
        all_filtered.append(filtered)
    raw.add_channels(all_filtered)
    if not order_by_frequency_band:
        # order channels by name and not by frequency band:
        # index the list with a stepsize of the number of channels for each of
        # the original channels
        chs_by_freq_band = []
        for i in range(len(original_ch_names)):
            chs_by_freq_band.extend(raw.ch_names[i :: len(original_ch_names)])
        raw.reorder_channels(chs_by_freq_band)
    if drop_original_signals:
        raw.drop_channels(original_ch_names)