braindecode 0.8.1__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of braindecode might be problematic. Click here for more details.
- braindecode/__init__.py +1 -2
- braindecode/augmentation/__init__.py +39 -19
- braindecode/augmentation/base.py +25 -28
- braindecode/augmentation/functional.py +237 -100
- braindecode/augmentation/transforms.py +326 -158
- braindecode/classifier.py +26 -24
- braindecode/datasets/__init__.py +28 -10
- braindecode/datasets/base.py +220 -134
- braindecode/datasets/bbci.py +43 -52
- braindecode/datasets/bcicomp.py +47 -32
- braindecode/datasets/bids.py +245 -0
- braindecode/datasets/mne.py +45 -24
- braindecode/datasets/moabb.py +87 -27
- braindecode/datasets/nmt.py +311 -0
- braindecode/datasets/sleep_physio_challe_18.py +412 -0
- braindecode/datasets/sleep_physionet.py +43 -26
- braindecode/datasets/tuh.py +324 -140
- braindecode/datasets/xy.py +27 -12
- braindecode/datautil/__init__.py +34 -18
- braindecode/datautil/serialization.py +98 -71
- braindecode/eegneuralnet.py +63 -47
- braindecode/functional/__init__.py +10 -0
- braindecode/functional/functions.py +251 -0
- braindecode/functional/initialization.py +47 -0
- braindecode/models/__init__.py +36 -14
- braindecode/models/atcnet.py +153 -159
- braindecode/models/attentionbasenet.py +550 -0
- braindecode/models/base.py +86 -102
- braindecode/models/biot.py +483 -0
- braindecode/models/contrawr.py +296 -0
- braindecode/models/ctnet.py +450 -0
- braindecode/models/deep4.py +64 -75
- braindecode/models/deepsleepnet.py +149 -119
- braindecode/models/eegconformer.py +111 -171
- braindecode/models/eeginception_erp.py +109 -118
- braindecode/models/eeginception_mi.py +155 -97
- braindecode/models/eegitnet.py +215 -151
- braindecode/models/eegminer.py +255 -0
- braindecode/models/eegnet.py +229 -161
- braindecode/models/eegnex.py +247 -0
- braindecode/models/eegresnet.py +234 -152
- braindecode/models/eegsimpleconv.py +199 -0
- braindecode/models/eegtcnet.py +335 -0
- braindecode/models/fbcnet.py +221 -0
- braindecode/models/fblightconvnet.py +313 -0
- braindecode/models/fbmsnet.py +325 -0
- braindecode/models/hybrid.py +52 -71
- braindecode/models/ifnet.py +441 -0
- braindecode/models/labram.py +1166 -0
- braindecode/models/msvtnet.py +375 -0
- braindecode/models/sccnet.py +182 -0
- braindecode/models/shallow_fbcsp.py +50 -56
- braindecode/models/signal_jepa.py +1012 -0
- braindecode/models/sinc_shallow.py +337 -0
- braindecode/models/sleep_stager_blanco_2020.py +55 -46
- braindecode/models/sleep_stager_chambon_2018.py +54 -53
- braindecode/models/sleep_stager_eldele_2021.py +248 -141
- braindecode/models/sparcnet.py +378 -0
- braindecode/models/summary.csv +41 -0
- braindecode/models/syncnet.py +232 -0
- braindecode/models/tcn.py +158 -88
- braindecode/models/tidnet.py +280 -167
- braindecode/models/tsinception.py +258 -0
- braindecode/models/usleep.py +190 -177
- braindecode/models/util.py +109 -141
- braindecode/modules/__init__.py +38 -0
- braindecode/modules/activation.py +60 -0
- braindecode/modules/attention.py +757 -0
- braindecode/modules/blocks.py +108 -0
- braindecode/modules/convolution.py +274 -0
- braindecode/modules/filter.py +632 -0
- braindecode/modules/layers.py +133 -0
- braindecode/modules/linear.py +50 -0
- braindecode/modules/parametrization.py +38 -0
- braindecode/modules/stats.py +77 -0
- braindecode/modules/util.py +77 -0
- braindecode/modules/wrapper.py +75 -0
- braindecode/preprocessing/__init__.py +36 -11
- braindecode/preprocessing/mne_preprocess.py +13 -7
- braindecode/preprocessing/preprocess.py +139 -75
- braindecode/preprocessing/windowers.py +576 -187
- braindecode/regressor.py +23 -12
- braindecode/samplers/__init__.py +16 -8
- braindecode/samplers/base.py +148 -32
- braindecode/samplers/ssl.py +162 -17
- braindecode/training/__init__.py +18 -10
- braindecode/training/callbacks.py +2 -4
- braindecode/training/losses.py +3 -8
- braindecode/training/scoring.py +76 -68
- braindecode/util.py +55 -59
- braindecode/version.py +1 -1
- braindecode/visualization/__init__.py +2 -3
- braindecode/visualization/confusion_matrices.py +117 -73
- braindecode/visualization/gradients.py +14 -10
- {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info}/METADATA +39 -55
- braindecode-1.0.0.dist-info/RECORD +101 -0
- {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info}/WHEEL +1 -1
- {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info/licenses}/LICENSE.txt +1 -1
- braindecode-1.0.0.dist-info/licenses/NOTICE.txt +20 -0
- braindecode/datautil/mne.py +0 -9
- braindecode/datautil/preprocess.py +0 -12
- braindecode/datautil/windowers.py +0 -6
- braindecode/datautil/xy.py +0 -9
- braindecode/models/eeginception.py +0 -317
- braindecode/models/functions.py +0 -47
- braindecode/models/modules.py +0 -358
- braindecode-0.8.1.dist-info/RECORD +0 -68
- {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info}/top_level.txt +0 -0
|
@@ -1,5 +1,4 @@
|
|
|
1
|
-
"""Preprocessors that work on Raw or Epochs objects.
|
|
2
|
-
"""
|
|
1
|
+
"""Preprocessors that work on Raw or Epochs objects."""
|
|
3
2
|
|
|
4
3
|
# Authors: Hubert Banville <hubert.jbanville@gmail.com>
|
|
5
4
|
# Lukas Gemein <l.gemein@gmail.com>
|
|
@@ -9,19 +8,36 @@
|
|
|
9
8
|
#
|
|
10
9
|
# License: BSD (3-clause)
|
|
11
10
|
|
|
12
|
-
from
|
|
13
|
-
|
|
11
|
+
from __future__ import annotations
|
|
12
|
+
|
|
13
|
+
import platform
|
|
14
|
+
import sys
|
|
14
15
|
from collections.abc import Iterable
|
|
16
|
+
from functools import partial
|
|
17
|
+
from warnings import warn
|
|
18
|
+
|
|
19
|
+
if sys.version_info < (3, 9):
|
|
20
|
+
from typing import Callable
|
|
21
|
+
else:
|
|
22
|
+
from collections.abc import Callable
|
|
15
23
|
|
|
16
24
|
import numpy as np
|
|
17
25
|
import pandas as pd
|
|
18
|
-
from mne import create_info
|
|
19
26
|
from joblib import Parallel, delayed
|
|
20
|
-
|
|
21
|
-
from
|
|
22
|
-
|
|
27
|
+
from mne import BaseEpochs, create_info
|
|
28
|
+
from mne.io import BaseRaw
|
|
29
|
+
from numpy.typing import NDArray
|
|
30
|
+
|
|
31
|
+
from braindecode.datasets.base import (
|
|
32
|
+
BaseConcatDataset,
|
|
33
|
+
BaseDataset,
|
|
34
|
+
EEGWindowsDataset,
|
|
35
|
+
WindowsDataset,
|
|
36
|
+
)
|
|
23
37
|
from braindecode.datautil.serialization import (
|
|
24
|
-
|
|
38
|
+
_check_save_dir_empty,
|
|
39
|
+
load_concat_dataset,
|
|
40
|
+
)
|
|
25
41
|
|
|
26
42
|
|
|
27
43
|
class Preprocessor(object):
|
|
@@ -50,20 +66,24 @@ class Preprocessor(object):
|
|
|
50
66
|
Keyword arguments to be forwarded to the MNE function.
|
|
51
67
|
"""
|
|
52
68
|
|
|
53
|
-
def __init__(self, fn, *, apply_on_array=True, **kwargs):
|
|
54
|
-
if hasattr(fn,
|
|
55
|
-
warn(
|
|
69
|
+
def __init__(self, fn: Callable | str, *, apply_on_array: bool = True, **kwargs):
|
|
70
|
+
if hasattr(fn, "__name__") and fn.__name__ == "<lambda>":
|
|
71
|
+
warn("Preprocessing choices with lambda functions cannot be saved.")
|
|
56
72
|
if callable(fn) and apply_on_array:
|
|
57
|
-
channel_wise = kwargs.pop(
|
|
58
|
-
picks = kwargs.pop(
|
|
59
|
-
n_jobs = kwargs.pop(
|
|
60
|
-
kwargs = dict(
|
|
61
|
-
|
|
62
|
-
|
|
73
|
+
channel_wise = kwargs.pop("channel_wise", False)
|
|
74
|
+
picks = kwargs.pop("picks", None)
|
|
75
|
+
n_jobs = kwargs.pop("n_jobs", 1)
|
|
76
|
+
kwargs = dict(
|
|
77
|
+
fun=partial(fn, **kwargs),
|
|
78
|
+
channel_wise=channel_wise,
|
|
79
|
+
picks=picks,
|
|
80
|
+
n_jobs=n_jobs,
|
|
81
|
+
)
|
|
82
|
+
fn = "apply_function"
|
|
63
83
|
self.fn = fn
|
|
64
84
|
self.kwargs = kwargs
|
|
65
85
|
|
|
66
|
-
def apply(self, raw_or_epochs):
|
|
86
|
+
def apply(self, raw_or_epochs: BaseRaw | BaseEpochs):
|
|
67
87
|
try:
|
|
68
88
|
self._try_apply(raw_or_epochs)
|
|
69
89
|
except RuntimeError:
|
|
@@ -80,13 +100,19 @@ class Preprocessor(object):
|
|
|
80
100
|
self.fn(raw_or_epochs, **self.kwargs)
|
|
81
101
|
else:
|
|
82
102
|
if not hasattr(raw_or_epochs, self.fn):
|
|
83
|
-
raise AttributeError(
|
|
84
|
-
f'MNE object does not have a {self.fn} method.')
|
|
103
|
+
raise AttributeError(f"MNE object does not have a {self.fn} method.")
|
|
85
104
|
getattr(raw_or_epochs, self.fn)(**self.kwargs)
|
|
86
105
|
|
|
87
106
|
|
|
88
|
-
def preprocess(
|
|
89
|
-
|
|
107
|
+
def preprocess(
|
|
108
|
+
concat_ds: BaseConcatDataset,
|
|
109
|
+
preprocessors: list[Preprocessor],
|
|
110
|
+
save_dir: str | None = None,
|
|
111
|
+
overwrite: bool = False,
|
|
112
|
+
n_jobs: int | None = None,
|
|
113
|
+
offset: int = 0,
|
|
114
|
+
copy_data: bool | None = None,
|
|
115
|
+
):
|
|
90
116
|
"""Apply preprocessors to a concat dataset.
|
|
91
117
|
|
|
92
118
|
Parameters
|
|
@@ -107,6 +133,14 @@ def preprocess(concat_ds, preprocessors, save_dir=None, overwrite=False,
|
|
|
107
133
|
n_jobs : int | None
|
|
108
134
|
Number of jobs for parallel execution. See `joblib.Parallel` for
|
|
109
135
|
a more detailed explanation.
|
|
136
|
+
offset : int
|
|
137
|
+
If provided, the integer is added to the id of the dataset in the
|
|
138
|
+
concat. This is useful in the setting of very large datasets, where
|
|
139
|
+
one dataset has to be processed and saved at a time to account for
|
|
140
|
+
its original position.
|
|
141
|
+
copy_data : bool | None
|
|
142
|
+
Whether the data passed to the different jobs should be copied or
|
|
143
|
+
passed by reference.
|
|
110
144
|
|
|
111
145
|
Returns
|
|
112
146
|
-------
|
|
@@ -119,24 +153,37 @@ def preprocess(concat_ds, preprocessors, save_dir=None, overwrite=False,
|
|
|
119
153
|
_check_save_dir_empty(save_dir)
|
|
120
154
|
|
|
121
155
|
if not isinstance(preprocessors, Iterable):
|
|
122
|
-
raise ValueError(
|
|
123
|
-
'preprocessors must be a list of Preprocessor objects.')
|
|
156
|
+
raise ValueError("preprocessors must be a list of Preprocessor objects.")
|
|
124
157
|
for elem in preprocessors:
|
|
125
|
-
assert hasattr(elem,
|
|
126
|
-
'Preprocessor object needs an `apply` method.')
|
|
158
|
+
assert hasattr(elem, "apply"), "Preprocessor object needs an `apply` method."
|
|
127
159
|
|
|
128
160
|
parallel_processing = (n_jobs is not None) and (n_jobs != 1)
|
|
129
161
|
|
|
130
|
-
|
|
162
|
+
job_prefer = "threads" if platform.system() == "Windows" else None
|
|
163
|
+
list_of_ds = Parallel(n_jobs=n_jobs, prefer=job_prefer)(
|
|
131
164
|
delayed(_preprocess)(
|
|
132
|
-
ds,
|
|
133
|
-
|
|
165
|
+
ds,
|
|
166
|
+
i + offset,
|
|
167
|
+
preprocessors,
|
|
168
|
+
save_dir,
|
|
169
|
+
overwrite,
|
|
170
|
+
copy_data=(
|
|
171
|
+
(parallel_processing and (save_dir is None))
|
|
172
|
+
if copy_data is None
|
|
173
|
+
else copy_data
|
|
174
|
+
),
|
|
175
|
+
)
|
|
134
176
|
for i, ds in enumerate(concat_ds.datasets)
|
|
135
177
|
)
|
|
136
178
|
|
|
137
179
|
if save_dir is not None: # Reload datasets and replace in concat_ds
|
|
180
|
+
ids_to_load = [i + offset for i in range(len(concat_ds.datasets))]
|
|
138
181
|
concat_ds_reloaded = load_concat_dataset(
|
|
139
|
-
save_dir,
|
|
182
|
+
save_dir,
|
|
183
|
+
preload=False,
|
|
184
|
+
target_name=None,
|
|
185
|
+
ids_to_load=ids_to_load,
|
|
186
|
+
)
|
|
140
187
|
_replace_inplace(concat_ds, concat_ds_reloaded)
|
|
141
188
|
else:
|
|
142
189
|
if parallel_processing: # joblib made copies
|
|
@@ -160,18 +207,21 @@ def _replace_inplace(concat_ds, new_concat_ds):
|
|
|
160
207
|
Dataset to use to modify ``concat_ds``.
|
|
161
208
|
"""
|
|
162
209
|
if len(concat_ds.datasets) != len(new_concat_ds.datasets):
|
|
163
|
-
raise ValueError(
|
|
210
|
+
raise ValueError("Both inputs must have the same length.")
|
|
164
211
|
for i in range(len(new_concat_ds.datasets)):
|
|
165
212
|
concat_ds.datasets[i] = new_concat_ds.datasets[i]
|
|
166
213
|
|
|
167
|
-
concat_kind =
|
|
168
|
-
preproc_kwargs_attr = concat_kind +
|
|
214
|
+
concat_kind = "raw" if hasattr(concat_ds.datasets[0], "raw") else "window"
|
|
215
|
+
preproc_kwargs_attr = concat_kind + "_preproc_kwargs"
|
|
169
216
|
if hasattr(new_concat_ds, preproc_kwargs_attr):
|
|
170
|
-
setattr(
|
|
171
|
-
|
|
217
|
+
setattr(
|
|
218
|
+
concat_ds, preproc_kwargs_attr, getattr(new_concat_ds, preproc_kwargs_attr)
|
|
219
|
+
)
|
|
172
220
|
|
|
173
221
|
|
|
174
|
-
def _preprocess(
|
|
222
|
+
def _preprocess(
|
|
223
|
+
ds, ds_index, preprocessors, save_dir=None, overwrite=False, copy_data=False
|
|
224
|
+
):
|
|
175
225
|
"""Apply preprocessor(s) to Raw or Epochs object.
|
|
176
226
|
|
|
177
227
|
Parameters
|
|
@@ -195,19 +245,24 @@ def _preprocess(ds, ds_index, preprocessors, save_dir=None, overwrite=False, cop
|
|
|
195
245
|
def _preprocess_raw_or_epochs(raw_or_epochs, preprocessors):
|
|
196
246
|
# Copying the data necessary in some scenarios for parallel processing
|
|
197
247
|
# to work when data is in memory (else error about _data not being writeable)
|
|
198
|
-
if
|
|
248
|
+
if raw_or_epochs.preload and copy_data:
|
|
199
249
|
raw_or_epochs._data = raw_or_epochs._data.copy()
|
|
200
250
|
for preproc in preprocessors:
|
|
201
251
|
preproc.apply(raw_or_epochs)
|
|
202
252
|
|
|
203
|
-
if hasattr(ds,
|
|
253
|
+
if hasattr(ds, "raw"):
|
|
254
|
+
if isinstance(ds, EEGWindowsDataset):
|
|
255
|
+
warn(
|
|
256
|
+
f"Applying preprocessors {preprocessors} to the mne.io.Raw of an EEGWindowsDataset."
|
|
257
|
+
)
|
|
204
258
|
_preprocess_raw_or_epochs(ds.raw, preprocessors)
|
|
205
|
-
elif hasattr(ds,
|
|
259
|
+
elif hasattr(ds, "windows"):
|
|
206
260
|
_preprocess_raw_or_epochs(ds.windows, preprocessors)
|
|
207
261
|
else:
|
|
208
262
|
raise ValueError(
|
|
209
|
-
|
|
210
|
-
|
|
263
|
+
"Can only preprocess concatenation of BaseDataset or "
|
|
264
|
+
"WindowsDataset, with either a `raw` or `windows` attribute."
|
|
265
|
+
)
|
|
211
266
|
|
|
212
267
|
# Store preprocessing keyword arguments in the dataset
|
|
213
268
|
_set_preproc_kwargs(ds, preprocessors)
|
|
@@ -231,9 +286,9 @@ def _get_preproc_kwargs(preprocessors):
|
|
|
231
286
|
func_name = p.fn.__name__
|
|
232
287
|
# if apply_on_array=True
|
|
233
288
|
else:
|
|
234
|
-
if
|
|
235
|
-
func_name = p.kwargs[
|
|
236
|
-
func_kwargs = p.kwargs[
|
|
289
|
+
if "fun" in p.fn:
|
|
290
|
+
func_name = p.kwargs["fun"].func.__name__
|
|
291
|
+
func_kwargs = p.kwargs["fun"].keywords
|
|
237
292
|
preproc_kwargs.append((func_name, func_kwargs))
|
|
238
293
|
return preproc_kwargs
|
|
239
294
|
|
|
@@ -250,19 +305,21 @@ def _set_preproc_kwargs(ds, preprocessors):
|
|
|
250
305
|
"""
|
|
251
306
|
preproc_kwargs = _get_preproc_kwargs(preprocessors)
|
|
252
307
|
if isinstance(ds, WindowsDataset):
|
|
253
|
-
kind =
|
|
308
|
+
kind = "window"
|
|
254
309
|
if isinstance(ds, EEGWindowsDataset):
|
|
255
|
-
kind =
|
|
310
|
+
kind = "raw"
|
|
256
311
|
elif isinstance(ds, BaseDataset):
|
|
257
|
-
kind =
|
|
312
|
+
kind = "raw"
|
|
258
313
|
else:
|
|
259
|
-
raise TypeError(
|
|
260
|
-
|
|
261
|
-
setattr(ds, kind + '_preproc_kwargs', preproc_kwargs)
|
|
314
|
+
raise TypeError(f"ds must be a BaseDataset or a WindowsDataset, got {type(ds)}")
|
|
315
|
+
setattr(ds, kind + "_preproc_kwargs", preproc_kwargs)
|
|
262
316
|
|
|
263
317
|
|
|
264
318
|
def exponential_moving_standardize(
|
|
265
|
-
|
|
319
|
+
data: NDArray,
|
|
320
|
+
factor_new: float = 0.001,
|
|
321
|
+
init_block_size: int | None = None,
|
|
322
|
+
eps: float = 1e-4,
|
|
266
323
|
):
|
|
267
324
|
r"""Perform exponential moving standardization.
|
|
268
325
|
|
|
@@ -300,18 +357,18 @@ def exponential_moving_standardize(
|
|
|
300
357
|
standardized = np.array(standardized)
|
|
301
358
|
if init_block_size is not None:
|
|
302
359
|
i_time_axis = 0
|
|
303
|
-
init_mean = np.mean(
|
|
304
|
-
|
|
305
|
-
)
|
|
306
|
-
|
|
307
|
-
data[0:init_block_size], axis=i_time_axis, keepdims=True
|
|
360
|
+
init_mean = np.mean(data[0:init_block_size], axis=i_time_axis, keepdims=True)
|
|
361
|
+
init_std = np.std(data[0:init_block_size], axis=i_time_axis, keepdims=True)
|
|
362
|
+
init_block_standardized = (data[0:init_block_size] - init_mean) / np.maximum(
|
|
363
|
+
eps, init_std
|
|
308
364
|
)
|
|
309
|
-
init_block_standardized = (data[0:init_block_size] - init_mean) / np.maximum(eps, init_std)
|
|
310
365
|
standardized[0:init_block_size] = init_block_standardized
|
|
311
366
|
return standardized.T
|
|
312
367
|
|
|
313
368
|
|
|
314
|
-
def exponential_moving_demean(
|
|
369
|
+
def exponential_moving_demean(
|
|
370
|
+
data: NDArray, factor_new: float = 0.001, init_block_size: int | None = None
|
|
371
|
+
):
|
|
315
372
|
r"""Perform exponential moving demeaning.
|
|
316
373
|
|
|
317
374
|
Compute the exponential moving mean :math:`m_t` at time `t` as
|
|
@@ -339,15 +396,18 @@ def exponential_moving_demean(data, factor_new=0.001, init_block_size=None):
|
|
|
339
396
|
demeaned = np.array(demeaned)
|
|
340
397
|
if init_block_size is not None:
|
|
341
398
|
i_time_axis = 0
|
|
342
|
-
init_mean = np.mean(
|
|
343
|
-
data[0:init_block_size], axis=i_time_axis, keepdims=True
|
|
344
|
-
)
|
|
399
|
+
init_mean = np.mean(data[0:init_block_size], axis=i_time_axis, keepdims=True)
|
|
345
400
|
demeaned[0:init_block_size] = data[0:init_block_size] - init_mean
|
|
346
401
|
return demeaned.T
|
|
347
402
|
|
|
348
403
|
|
|
349
|
-
def filterbank(
|
|
350
|
-
|
|
404
|
+
def filterbank(
|
|
405
|
+
raw: BaseRaw,
|
|
406
|
+
frequency_bands: list[tuple[float, float]],
|
|
407
|
+
drop_original_signals: bool = True,
|
|
408
|
+
order_by_frequency_band: bool = False,
|
|
409
|
+
**mne_filter_kwargs,
|
|
410
|
+
):
|
|
351
411
|
"""Applies multiple bandpass filters to the signals in raw. The raw will be
|
|
352
412
|
modified in-place and number of channels in raw will be updated to
|
|
353
413
|
len(frequency_bands) * len(raw.ch_names) (-len(raw.ch_names) if
|
|
@@ -371,15 +431,16 @@ def filterbank(raw, frequency_bands, drop_original_signals=True,
|
|
|
371
431
|
Please refer to mne for a detailed explanation.
|
|
372
432
|
"""
|
|
373
433
|
if not frequency_bands:
|
|
374
|
-
raise ValueError(f"Expected at least one frequency band, got"
|
|
375
|
-
f" {frequency_bands}")
|
|
434
|
+
raise ValueError(f"Expected at least one frequency band, got {frequency_bands}")
|
|
376
435
|
if not all([len(ch_name) < 8 for ch_name in raw.ch_names]):
|
|
377
|
-
warn(
|
|
378
|
-
|
|
379
|
-
|
|
436
|
+
warn(
|
|
437
|
+
"Try to use shorter channel names, since frequency band "
|
|
438
|
+
"annotation requires an estimated 4-8 chars depending on the "
|
|
439
|
+
"frequency ranges. Will truncate to 15 chars (mne max)."
|
|
440
|
+
)
|
|
380
441
|
original_ch_names = raw.ch_names
|
|
381
442
|
all_filtered = []
|
|
382
|
-
for
|
|
443
|
+
for l_freq, h_freq in frequency_bands:
|
|
383
444
|
filtered = raw.copy()
|
|
384
445
|
filtered.filter(l_freq=l_freq, h_freq=h_freq, **mne_filter_kwargs)
|
|
385
446
|
# mne automatically changes the highpass/lowpass info values
|
|
@@ -389,7 +450,7 @@ def filterbank(raw, frequency_bands, drop_original_signals=True,
|
|
|
389
450
|
|
|
390
451
|
ch_names = filtered.info.ch_names
|
|
391
452
|
ch_types = filtered.info.get_channel_types()
|
|
392
|
-
sampling_freq = filtered.info[
|
|
453
|
+
sampling_freq = filtered.info["sfreq"]
|
|
393
454
|
|
|
394
455
|
info = create_info(ch_names=ch_names, ch_types=ch_types, sfreq=sampling_freq)
|
|
395
456
|
|
|
@@ -397,9 +458,12 @@ def filterbank(raw, frequency_bands, drop_original_signals=True,
|
|
|
397
458
|
|
|
398
459
|
# add frequency band annotation to channel names
|
|
399
460
|
# truncate to a max of 15 characters, since mne does not allow for more
|
|
400
|
-
filtered.rename_channels(
|
|
401
|
-
|
|
402
|
-
|
|
461
|
+
filtered.rename_channels(
|
|
462
|
+
{
|
|
463
|
+
old_name: (old_name + f"_{l_freq}-{h_freq}")[-15:]
|
|
464
|
+
for old_name in filtered.ch_names
|
|
465
|
+
}
|
|
466
|
+
)
|
|
403
467
|
all_filtered.append(filtered)
|
|
404
468
|
raw.add_channels(all_filtered)
|
|
405
469
|
if not order_by_frequency_band:
|
|
@@ -408,7 +472,7 @@ def filterbank(raw, frequency_bands, drop_original_signals=True,
|
|
|
408
472
|
# the original channels
|
|
409
473
|
chs_by_freq_band = []
|
|
410
474
|
for i in range(len(original_ch_names)):
|
|
411
|
-
chs_by_freq_band.extend(raw.ch_names[i::len(original_ch_names)])
|
|
475
|
+
chs_by_freq_band.extend(raw.ch_names[i :: len(original_ch_names)])
|
|
412
476
|
raw.reorder_channels(chs_by_freq_band)
|
|
413
477
|
if drop_original_signals:
|
|
414
478
|
raw.drop_channels(original_ch_names)
|