braindecode 0.8__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of braindecode might be problematic. Click here for more details.

Files changed (102)
  1. braindecode/__init__.py +1 -2
  2. braindecode/augmentation/__init__.py +50 -0
  3. braindecode/augmentation/base.py +222 -0
  4. braindecode/augmentation/functional.py +1096 -0
  5. braindecode/augmentation/transforms.py +1274 -0
  6. braindecode/classifier.py +26 -24
  7. braindecode/datasets/__init__.py +34 -0
  8. braindecode/datasets/base.py +840 -0
  9. braindecode/datasets/bbci.py +694 -0
  10. braindecode/datasets/bcicomp.py +194 -0
  11. braindecode/datasets/bids.py +245 -0
  12. braindecode/datasets/mne.py +172 -0
  13. braindecode/datasets/moabb.py +209 -0
  14. braindecode/datasets/nmt.py +311 -0
  15. braindecode/datasets/sleep_physio_challe_18.py +412 -0
  16. braindecode/datasets/sleep_physionet.py +125 -0
  17. braindecode/datasets/tuh.py +588 -0
  18. braindecode/datasets/xy.py +95 -0
  19. braindecode/datautil/__init__.py +49 -0
  20. braindecode/datautil/serialization.py +342 -0
  21. braindecode/datautil/util.py +41 -0
  22. braindecode/eegneuralnet.py +63 -47
  23. braindecode/functional/__init__.py +10 -0
  24. braindecode/functional/functions.py +251 -0
  25. braindecode/functional/initialization.py +47 -0
  26. braindecode/models/__init__.py +52 -0
  27. braindecode/models/atcnet.py +652 -0
  28. braindecode/models/attentionbasenet.py +550 -0
  29. braindecode/models/base.py +296 -0
  30. braindecode/models/biot.py +483 -0
  31. braindecode/models/contrawr.py +296 -0
  32. braindecode/models/ctnet.py +450 -0
  33. braindecode/models/deep4.py +322 -0
  34. braindecode/models/deepsleepnet.py +295 -0
  35. braindecode/models/eegconformer.py +372 -0
  36. braindecode/models/eeginception_erp.py +304 -0
  37. braindecode/models/eeginception_mi.py +371 -0
  38. braindecode/models/eegitnet.py +301 -0
  39. braindecode/models/eegminer.py +255 -0
  40. braindecode/models/eegnet.py +473 -0
  41. braindecode/models/eegnex.py +247 -0
  42. braindecode/models/eegresnet.py +362 -0
  43. braindecode/models/eegsimpleconv.py +199 -0
  44. braindecode/models/eegtcnet.py +335 -0
  45. braindecode/models/fbcnet.py +221 -0
  46. braindecode/models/fblightconvnet.py +313 -0
  47. braindecode/models/fbmsnet.py +325 -0
  48. braindecode/models/hybrid.py +126 -0
  49. braindecode/models/ifnet.py +441 -0
  50. braindecode/models/labram.py +1166 -0
  51. braindecode/models/msvtnet.py +375 -0
  52. braindecode/models/sccnet.py +182 -0
  53. braindecode/models/shallow_fbcsp.py +208 -0
  54. braindecode/models/signal_jepa.py +1012 -0
  55. braindecode/models/sinc_shallow.py +337 -0
  56. braindecode/models/sleep_stager_blanco_2020.py +167 -0
  57. braindecode/models/sleep_stager_chambon_2018.py +157 -0
  58. braindecode/models/sleep_stager_eldele_2021.py +536 -0
  59. braindecode/models/sparcnet.py +378 -0
  60. braindecode/models/summary.csv +41 -0
  61. braindecode/models/syncnet.py +232 -0
  62. braindecode/models/tcn.py +273 -0
  63. braindecode/models/tidnet.py +395 -0
  64. braindecode/models/tsinception.py +258 -0
  65. braindecode/models/usleep.py +340 -0
  66. braindecode/models/util.py +133 -0
  67. braindecode/modules/__init__.py +38 -0
  68. braindecode/modules/activation.py +60 -0
  69. braindecode/modules/attention.py +757 -0
  70. braindecode/modules/blocks.py +108 -0
  71. braindecode/modules/convolution.py +274 -0
  72. braindecode/modules/filter.py +632 -0
  73. braindecode/modules/layers.py +133 -0
  74. braindecode/modules/linear.py +50 -0
  75. braindecode/modules/parametrization.py +38 -0
  76. braindecode/modules/stats.py +77 -0
  77. braindecode/modules/util.py +77 -0
  78. braindecode/modules/wrapper.py +75 -0
  79. braindecode/preprocessing/__init__.py +37 -0
  80. braindecode/preprocessing/mne_preprocess.py +77 -0
  81. braindecode/preprocessing/preprocess.py +478 -0
  82. braindecode/preprocessing/windowers.py +1031 -0
  83. braindecode/regressor.py +23 -12
  84. braindecode/samplers/__init__.py +18 -0
  85. braindecode/samplers/base.py +401 -0
  86. braindecode/samplers/ssl.py +263 -0
  87. braindecode/training/__init__.py +23 -0
  88. braindecode/training/callbacks.py +23 -0
  89. braindecode/training/losses.py +105 -0
  90. braindecode/training/scoring.py +483 -0
  91. braindecode/util.py +55 -59
  92. braindecode/version.py +1 -1
  93. braindecode/visualization/__init__.py +8 -0
  94. braindecode/visualization/confusion_matrices.py +289 -0
  95. braindecode/visualization/gradients.py +57 -0
  96. {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info}/METADATA +39 -55
  97. braindecode-1.0.0.dist-info/RECORD +101 -0
  98. {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info}/WHEEL +1 -1
  99. {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info/licenses}/LICENSE.txt +1 -1
  100. braindecode-1.0.0.dist-info/licenses/NOTICE.txt +20 -0
  101. braindecode-0.8.dist-info/RECORD +0 -11
  102. {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,478 @@
1
+ """Preprocessors that work on Raw or Epochs objects."""
2
+
3
+ # Authors: Hubert Banville <hubert.jbanville@gmail.com>
4
+ # Lukas Gemein <l.gemein@gmail.com>
5
+ # Simon Brandt <simonbrandt@protonmail.com>
6
+ # David Sabbagh <dav.sabbagh@gmail.com>
7
+ # Bruno Aristimunha <b.aristimunha@gmail.com>
8
+ #
9
+ # License: BSD (3-clause)
10
+
11
+ from __future__ import annotations
12
+
13
+ import platform
14
+ import sys
15
+ from collections.abc import Iterable
16
+ from functools import partial
17
+ from warnings import warn
18
+
19
+ if sys.version_info < (3, 9):
20
+ from typing import Callable
21
+ else:
22
+ from collections.abc import Callable
23
+
24
+ import numpy as np
25
+ import pandas as pd
26
+ from joblib import Parallel, delayed
27
+ from mne import BaseEpochs, create_info
28
+ from mne.io import BaseRaw
29
+ from numpy.typing import NDArray
30
+
31
+ from braindecode.datasets.base import (
32
+ BaseConcatDataset,
33
+ BaseDataset,
34
+ EEGWindowsDataset,
35
+ WindowsDataset,
36
+ )
37
+ from braindecode.datautil.serialization import (
38
+ _check_save_dir_empty,
39
+ load_concat_dataset,
40
+ )
41
+
42
+
43
class Preprocessor:
    """Preprocessor for an MNE Raw or Epochs object.

    Applies the provided preprocessing function to the data of a Raw or Epochs
    object.
    If the function is provided as a string, the method with that name will be
    used (e.g., 'pick_channels', 'filter', etc.).
    If it is provided as a callable and `apply_on_array` is True, the
    `apply_function` method of Raw and Epochs object will be used to apply the
    function on the internal arrays of Raw and Epochs.
    If `apply_on_array` is False, the callable must directly modify the Raw or
    Epochs object (e.g., by calling its method(s) or modifying its attributes).

    Parameters
    ----------
    fn : str or callable
        If str, the Raw/Epochs object must have a method with that name.
        If callable, directly apply the callable to the object.
    apply_on_array : bool
        Ignored if `fn` is not a callable. If True, the `apply_function` of Raw
        and Epochs object will be used to run `fn` on the underlying arrays
        directly. If False, `fn` must directly modify the Raw or Epochs object.
    kwargs :
        Keyword arguments to be forwarded to the MNE function. When `fn` is a
        callable and `apply_on_array` is True, the keys ``channel_wise``,
        ``picks`` and ``n_jobs`` (defaults ``False``, ``None`` and ``1``) are
        passed to ``apply_function`` itself, and every other key is bound to
        `fn` with :func:`functools.partial`.

    Notes
    -----
    A warning is emitted when `fn` is a lambda, since such preprocessing
    choices cannot be saved.
    """

    def __init__(self, fn: Callable | str, *, apply_on_array: bool = True, **kwargs):
        if hasattr(fn, "__name__") and fn.__name__ == "<lambda>":
            warn("Preprocessing choices with lambda functions cannot be saved.")
        if callable(fn) and apply_on_array:
            # Route the callable through MNE's ``apply_function``; these three
            # keys configure ``apply_function`` itself, not ``fn``.
            channel_wise = kwargs.pop("channel_wise", False)
            picks = kwargs.pop("picks", None)
            n_jobs = kwargs.pop("n_jobs", 1)
            kwargs = dict(
                fun=partial(fn, **kwargs),
                channel_wise=channel_wise,
                picks=picks,
                n_jobs=n_jobs,
            )
            fn = "apply_function"
        self.fn = fn
        self.kwargs = kwargs

    def __repr__(self) -> str:
        """Readable summary showing the user-facing function and its kwargs."""
        fn, kwargs = self.fn, self.kwargs
        fun = kwargs.get("fun")
        if isinstance(fn, str) and isinstance(fun, partial):
            # Show the wrapped user function instead of ``apply_function``.
            fn, kwargs = fun.func, fun.keywords
        name = fn if isinstance(fn, str) else getattr(fn, "__name__", repr(fn))
        return f"{type(self).__name__}(fn={name!r}, kwargs={kwargs!r})"

    def apply(self, raw_or_epochs: BaseRaw | BaseEpochs):
        """Apply the preprocessing step in-place to ``raw_or_epochs``.

        The step is first attempted without loading the data: not all MNE
        functions need data to be loaded (most importantly ``crop`` can be
        applied lazily), which can make the overall preprocessing pipeline
        substantially faster. If that raises a ``RuntimeError`` and the data
        is not preloaded yet, the data is loaded and the step is retried once.

        Raises
        ------
        RuntimeError
            Re-raised unchanged when the object was already preloaded: loading
            again cannot help, and a retry could re-run a step that already
            partially executed.
        AttributeError
            If `fn` is a string naming a method the object does not have.
        """
        try:
            self._try_apply(raw_or_epochs)
        except RuntimeError:
            if getattr(raw_or_epochs, "preload", False):
                raise
            raw_or_epochs.load_data()
            self._try_apply(raw_or_epochs)

    def _try_apply(self, raw_or_epochs):
        """Run the step once, without any load-and-retry logic."""
        if callable(self.fn):
            self.fn(raw_or_epochs, **self.kwargs)
        else:
            if not hasattr(raw_or_epochs, self.fn):
                raise AttributeError(f"MNE object does not have a {self.fn} method.")
            getattr(raw_or_epochs, self.fn)(**self.kwargs)
105
+
106
+
107
def preprocess(
    concat_ds: BaseConcatDataset,
    preprocessors: list[Preprocessor],
    save_dir: str | None = None,
    overwrite: bool = False,
    n_jobs: int | None = None,
    offset: int = 0,
    copy_data: bool | None = None,
):
    """Apply preprocessors to a concat dataset.

    Parameters
    ----------
    concat_ds : BaseConcatDataset
        A concat of BaseDataset or WindowsDataset datasets to be preprocessed.
    preprocessors : list(Preprocessor)
        List of Preprocessor objects to apply to the dataset. Any iterable is
        accepted; it is materialized into a list once, so generators work too.
    save_dir : str | None
        If a string, the preprocessed data will be saved under the specified
        directory and the datasets in ``concat_ds`` will be reloaded with
        `preload=False`.
    overwrite : bool
        When `save_dir` is provided, controls whether to delete the old
        subdirectories that will be written to under `save_dir`. If False and
        the corresponding subdirectories already exist, a ``FileExistsError``
        will be raised.
    n_jobs : int | None
        Number of jobs for parallel execution. See `joblib.Parallel` for
        a more detailed explanation.
    offset : int
        If provided, the integer is added to the id of the dataset in the
        concat. This is useful in the setting of very large datasets, where
        one dataset has to be processed and saved at a time to account for
        its original position.
    copy_data : bool | None
        Whether the data passed to the different jobs should be copied or
        passed by reference. If None, data is copied only when running in
        parallel (``n_jobs`` not None and not 1) without ``save_dir``.

    Returns
    -------
    BaseConcatDataset:
        Preprocessed dataset; this is ``concat_ds`` itself, modified in place.

    Raises
    ------
    ValueError
        If ``preprocessors`` is not iterable.
    AssertionError
        If an element of ``preprocessors`` has no ``apply`` method.
    """
    # In case of serialization, make sure directory is available before
    # preprocessing
    if save_dir is not None and not overwrite:
        _check_save_dir_empty(save_dir)

    if not isinstance(preprocessors, Iterable):
        raise ValueError("preprocessors must be a list of Preprocessor objects.")
    # Materialize once: validating a one-shot iterator (e.g. a generator)
    # would exhaust it, and every dataset would silently get no preprocessing.
    preprocessors = list(preprocessors)
    for elem in preprocessors:
        # Explicit raise (not ``assert``) so the check survives ``python -O``;
        # the exception type is kept for backward compatibility.
        if not hasattr(elem, "apply"):
            raise AssertionError("Preprocessor object needs an `apply` method.")

    parallel_processing = (n_jobs is not None) and (n_jobs != 1)
    if copy_data is None:
        # Workers need their own writeable buffers when running in parallel
        # on in-memory data.
        copy_data = parallel_processing and (save_dir is None)

    job_prefer = "threads" if platform.system() == "Windows" else None
    list_of_ds = Parallel(n_jobs=n_jobs, prefer=job_prefer)(
        delayed(_preprocess)(
            ds,
            i + offset,
            preprocessors,
            save_dir,
            overwrite,
            copy_data=copy_data,
        )
        for i, ds in enumerate(concat_ds.datasets)
    )

    if save_dir is not None:
        # Workers wrote to disk: reload lazily and swap into concat_ds.
        ids_to_load = [i + offset for i in range(len(concat_ds.datasets))]
        concat_ds_reloaded = load_concat_dataset(
            save_dir,
            preload=False,
            target_name=None,
            ids_to_load=ids_to_load,
        )
        _replace_inplace(concat_ds, concat_ds_reloaded)
    elif parallel_processing:
        # joblib returned copies; put them back into concat_ds.
        _replace_inplace(concat_ds, BaseConcatDataset(list_of_ds))
    # Otherwise joblib did not make copies and preprocessing happened in place.

    # Recompute cumulative sizes on every path: transforms (e.g. cropping) may
    # have changed dataset lengths, and _replace_inplace does not update them.
    concat_ds.cumulative_sizes = concat_ds.cumsum(concat_ds.datasets)

    return concat_ds
197
+
198
+
199
+ def _replace_inplace(concat_ds, new_concat_ds):
200
+ """Replace subdatasets and preproc_kwargs of a BaseConcatDataset inplace.
201
+
202
+ Parameters
203
+ ----------
204
+ concat_ds : BaseConcatDataset
205
+ Dataset to modify inplace.
206
+ new_concat_ds : BaseConcatDataset
207
+ Dataset to use to modify ``concat_ds``.
208
+ """
209
+ if len(concat_ds.datasets) != len(new_concat_ds.datasets):
210
+ raise ValueError("Both inputs must have the same length.")
211
+ for i in range(len(new_concat_ds.datasets)):
212
+ concat_ds.datasets[i] = new_concat_ds.datasets[i]
213
+
214
+ concat_kind = "raw" if hasattr(concat_ds.datasets[0], "raw") else "window"
215
+ preproc_kwargs_attr = concat_kind + "_preproc_kwargs"
216
+ if hasattr(new_concat_ds, preproc_kwargs_attr):
217
+ setattr(
218
+ concat_ds, preproc_kwargs_attr, getattr(new_concat_ds, preproc_kwargs_attr)
219
+ )
220
+
221
+
222
def _preprocess(
    ds, ds_index, preprocessors, save_dir=None, overwrite=False, copy_data=False
):
    """Apply preprocessor(s) to the Raw or Epochs object of a single dataset.

    Parameters
    ----------
    ds : BaseDataset | WindowsDataset
        Dataset to preprocess. If it has a ``raw`` attribute, ``ds.raw`` is
        processed (with a warning for EEGWindowsDataset); otherwise
        ``ds.windows`` is processed.
    ds_index : int
        Index of the dataset in its BaseConcatDataset, used as the save
        offset. Ignored if ``save_dir`` is None.
    preprocessors : list(Preprocessor)
        Preprocessors, applied in order.
    save_dir : str | None
        If provided, save the preprocessed dataset in this directory.
    overwrite : bool
        If True, overwrite existing file with the same name.
    copy_data : bool
        If True and the data is preloaded, replace ``_data`` with a copy
        before processing. Needed for parallel processing, where the shared
        array may otherwise not be writeable.

    Returns
    -------
    BaseDataset | WindowsDataset | None
        ``ds`` when ``save_dir`` is None, otherwise None.

    Raises
    ------
    ValueError
        If ``ds`` has neither a ``raw`` nor a ``windows`` attribute.
    """
    if hasattr(ds, "raw"):
        if isinstance(ds, EEGWindowsDataset):
            warn(
                f"Applying preprocessors {preprocessors} to the mne.io.Raw of an EEGWindowsDataset."
            )
        target = ds.raw
    elif hasattr(ds, "windows"):
        target = ds.windows
    else:
        raise ValueError(
            "Can only preprocess concatenation of BaseDataset or "
            "WindowsDataset, with either a `raw` or `windows` attribute."
        )

    # Give this job a private, writeable buffer for in-memory data.
    if target.preload and copy_data:
        target._data = target._data.copy()
    for step in preprocessors:
        step.apply(target)

    # Keep a record of what was applied on the dataset itself.
    _set_preproc_kwargs(ds, preprocessors)

    if save_dir is None:
        return ds
    BaseConcatDataset([ds]).save(save_dir, overwrite=overwrite, offset=ds_index)
    return None
275
+
276
+
277
+ def _get_preproc_kwargs(preprocessors):
278
+ preproc_kwargs = []
279
+ for p in preprocessors:
280
+ # in case of a mne function, fn is a str, kwargs is a dict
281
+ func_name = p.fn
282
+ func_kwargs = p.kwargs
283
+ # in case of another function
284
+ # if apply_on_array=False
285
+ if callable(p.fn):
286
+ func_name = p.fn.__name__
287
+ # if apply_on_array=True
288
+ else:
289
+ if "fun" in p.fn:
290
+ func_name = p.kwargs["fun"].func.__name__
291
+ func_kwargs = p.kwargs["fun"].keywords
292
+ preproc_kwargs.append((func_name, func_kwargs))
293
+ return preproc_kwargs
294
+
295
+
296
def _set_preproc_kwargs(ds, preprocessors):
    """Record preprocessing keyword arguments in BaseDataset or WindowsDataset.

    The summary built by ``_get_preproc_kwargs`` is stored on ``ds`` under
    one of two attribute names:

    - ``raw_preproc_kwargs`` for BaseDataset and for EEGWindowsDataset, whose
      ``raw`` object is the one that gets preprocessed;
    - ``window_preproc_kwargs`` for WindowsDataset.

    Parameters
    ----------
    ds : BaseDataset | WindowsDataset
        Dataset in which to record preprocessing keyword arguments.
    preprocessors : list
        List of preprocessors.

    Raises
    ------
    TypeError
        If ``ds`` is none of the supported dataset types.
    """
    preproc_kwargs = _get_preproc_kwargs(preprocessors)
    # One exclusive chain, most specific type first, so no later check can
    # override (or reject) an earlier decision.
    if isinstance(ds, EEGWindowsDataset):
        kind = "raw"
    elif isinstance(ds, WindowsDataset):
        kind = "window"
    elif isinstance(ds, BaseDataset):
        kind = "raw"
    else:
        raise TypeError(f"ds must be a BaseDataset or a WindowsDataset, got {type(ds)}")
    setattr(ds, kind + "_preproc_kwargs", preproc_kwargs)
316
+
317
+
318
def exponential_moving_standardize(
    data: NDArray,
    factor_new: float = 0.001,
    init_block_size: int | None = None,
    eps: float = 1e-4,
):
    r"""Perform exponential moving standardization.

    Compute the exponential moving mean :math:`m_t` of each channel at time
    `t` as
    :math:`m_t=\mathrm{factornew} \cdot x_t + (1 - \mathrm{factornew}) \cdot m_{t-1}`.

    Then, compute the exponential moving variance :math:`v_t` at time `t` as
    :math:`v_t=\mathrm{factornew} \cdot (m_t - x_t)^2 + (1 - \mathrm{factornew}) \cdot v_{t-1}`.

    Finally, standardize the data point :math:`x_t` at time `t` as
    :math:`x'_t=(x_t - m_t) / \max(\sqrt{v_t}, \mathrm{eps})`.

    Both moving averages are computed per channel with
    :meth:`pandas.DataFrame.ewm` using ``alpha=factor_new`` and the pandas
    default ``adjust=True``. The recursions above are therefore only
    approximate for the first samples, where pandas renormalizes the weights.

    Parameters
    ----------
    data : np.ndarray, shape (n_channels, n_times)
        Signal to standardize.
    factor_new : float
        Smoothing factor (``alpha``) of the exponential moving averages.
    init_block_size : int | None
        If not None, the samples before this index are instead standardized
        with the plain mean and standard deviation of that initial block.
    eps : float
        Stabilizer for division by zero variance.

    Returns
    -------
    standardized : np.ndarray, shape (n_channels, n_times)
        Standardized data.
    """
    samples = data.T  # time along rows, so pandas smooths each channel column
    frame = pd.DataFrame(samples)
    deviation = frame - frame.ewm(alpha=factor_new).mean()
    moving_var = (deviation * deviation).ewm(alpha=factor_new).mean()
    result = np.array(deviation / np.maximum(eps, np.sqrt(np.array(moving_var))))

    if init_block_size is not None:
        head = samples[0:init_block_size]
        head_mean = np.mean(head, axis=0, keepdims=True)
        head_std = np.std(head, axis=0, keepdims=True)
        result[0:init_block_size] = (head - head_mean) / np.maximum(eps, head_std)

    return result.T
367
+
368
+
369
def exponential_moving_demean(
    data: NDArray, factor_new: float = 0.001, init_block_size: int | None = None
):
    r"""Perform exponential moving demeaning.

    Compute the exponential moving mean :math:`m_t` of each channel at time
    `t` as
    :math:`m_t=\mathrm{factornew} \cdot x_t + (1 - \mathrm{factornew}) \cdot m_{t-1}`.

    Demean the data point :math:`x_t` at time `t` as
    :math:`x'_t=(x_t - m_t)`.

    The moving mean is computed per channel with
    :meth:`pandas.DataFrame.ewm` using ``alpha=factor_new`` and the pandas
    default ``adjust=True``. The recursion above is therefore only approximate
    for the first samples.

    Parameters
    ----------
    data : np.ndarray, shape (n_channels, n_times)
        Signal to demean.
    factor_new : float
        Smoothing factor (``alpha``) of the exponential moving mean.
    init_block_size : int | None
        If not None, the samples before this index are instead demeaned with
        the plain mean of that initial block.

    Returns
    -------
    demeaned : np.ndarray, shape (n_channels, n_times)
        Demeaned data.
    """
    samples = data.T  # time along rows, so pandas smooths each channel column
    frame = pd.DataFrame(samples)
    result = np.array(frame - frame.ewm(alpha=factor_new).mean())

    if init_block_size is not None:
        head = samples[0:init_block_size]
        result[0:init_block_size] = head - np.mean(head, axis=0, keepdims=True)

    return result.T
402
+
403
+
404
def filterbank(
    raw: BaseRaw,
    frequency_bands: list[tuple[float, float]],
    drop_original_signals: bool = True,
    order_by_frequency_band: bool = False,
    **mne_filter_kwargs,
):
    """Apply multiple bandpass filters to the signals in raw.

    The raw is modified in-place: one band-filtered copy of every channel is
    added per frequency band. ``raw`` therefore ends up with
    ``(len(frequency_bands) + 1) * n_channels`` channels, or
    ``len(frequency_bands) * n_channels`` if ``drop_original_signals`` is True,
    where ``n_channels`` is the number of channels before the call.

    Parameters
    ----------
    raw : mne.io.Raw
        The raw signals to be filtered.
    frequency_bands : list(tuple)
        The frequency bands to be filtered for (e.g. [(4, 8), (8, 13)]).
    drop_original_signals : bool
        Whether to drop the original unfiltered signals.
    order_by_frequency_band : bool
        If True will return channels ordered by frequency bands, so if there
        are channels Cz, O1 and filterbank ranges [(4,8), (8,13)], returned
        channels will be [Cz_4-8, O1_4-8, Cz_8-13, O1_8-13]. If False, order
        will be [Cz_4-8, Cz_8-13, O1_4-8, O1_8-13].
    mne_filter_kwargs : dict
        Keyword arguments for filtering supported by mne.io.Raw.filter().
        Please refer to mne for a detailed explanation.

    Raises
    ------
    ValueError
        If ``frequency_bands`` is empty.
    """
    if not frequency_bands:
        raise ValueError(f"Expected at least one frequency band, got {frequency_bands}")
    if not all(len(ch_name) < 8 for ch_name in raw.ch_names):
        warn(
            "Try to use shorter channel names, since frequency band "
            "annotation requires an estimated 4-8 chars depending on the "
            "frequency ranges. Will truncate to 15 chars (mne max)."
        )
    # Snapshot the names: raw's channel list changes once channels are added.
    original_ch_names = list(raw.ch_names)
    n_original = len(original_ch_names)
    all_filtered = []
    for l_freq, h_freq in frequency_bands:
        filtered = raw.copy()
        filtered.filter(l_freq=l_freq, h_freq=h_freq, **mne_filter_kwargs)
        # mne automatically changes the highpass/lowpass info values when
        # applying filters, and channels can't be added if they have different
        # such parameters. Rebuild a minimal Info (names, types, sfreq only);
        # note this also drops any other channel metadata of the band copies.
        filtered.info = create_info(
            ch_names=filtered.info.ch_names,
            ch_types=filtered.info.get_channel_types(),
            sfreq=filtered.info["sfreq"],
        )
        # Add the frequency band annotation to channel names, keeping at most
        # the last 15 characters since mne does not allow for more.
        suffix = f"_{l_freq}-{h_freq}"
        filtered.rename_channels(
            {old_name: (old_name + suffix)[-15:] for old_name in filtered.ch_names}
        )
        all_filtered.append(filtered)
    raw.add_channels(all_filtered)
    if not order_by_frequency_band:
        # Channels are now [originals..., band 1..., band 2..., ...]; taking
        # every n_original-th name from position i groups channel i with all
        # of its band-filtered versions.
        chs_by_freq_band = []
        for i in range(n_original):
            chs_by_freq_band.extend(raw.ch_names[i::n_original])
        raw.reorder_channels(chs_by_freq_band)
    if drop_original_signals:
        raw.drop_channels(original_ch_names)