braindecode 1.3.0.dev177069446__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (124)
  1. braindecode/__init__.py +9 -0
  2. braindecode/augmentation/__init__.py +52 -0
  3. braindecode/augmentation/base.py +225 -0
  4. braindecode/augmentation/functional.py +1300 -0
  5. braindecode/augmentation/transforms.py +1356 -0
  6. braindecode/classifier.py +258 -0
  7. braindecode/datasets/__init__.py +44 -0
  8. braindecode/datasets/base.py +823 -0
  9. braindecode/datasets/bbci.py +693 -0
  10. braindecode/datasets/bcicomp.py +193 -0
  11. braindecode/datasets/bids/__init__.py +54 -0
  12. braindecode/datasets/bids/datasets.py +239 -0
  13. braindecode/datasets/bids/format.py +717 -0
  14. braindecode/datasets/bids/hub.py +987 -0
  15. braindecode/datasets/bids/hub_format.py +717 -0
  16. braindecode/datasets/bids/hub_io.py +197 -0
  17. braindecode/datasets/bids/hub_validation.py +114 -0
  18. braindecode/datasets/bids/iterable.py +220 -0
  19. braindecode/datasets/chb_mit.py +163 -0
  20. braindecode/datasets/mne.py +170 -0
  21. braindecode/datasets/moabb.py +219 -0
  22. braindecode/datasets/nmt.py +313 -0
  23. braindecode/datasets/registry.py +120 -0
  24. braindecode/datasets/siena.py +162 -0
  25. braindecode/datasets/sleep_physio_challe_18.py +411 -0
  26. braindecode/datasets/sleep_physionet.py +125 -0
  27. braindecode/datasets/tuh.py +591 -0
  28. braindecode/datasets/utils.py +67 -0
  29. braindecode/datasets/xy.py +96 -0
  30. braindecode/datautil/__init__.py +62 -0
  31. braindecode/datautil/channel_utils.py +114 -0
  32. braindecode/datautil/hub_formats.py +180 -0
  33. braindecode/datautil/serialization.py +359 -0
  34. braindecode/datautil/util.py +154 -0
  35. braindecode/eegneuralnet.py +372 -0
  36. braindecode/functional/__init__.py +22 -0
  37. braindecode/functional/functions.py +251 -0
  38. braindecode/functional/initialization.py +47 -0
  39. braindecode/models/__init__.py +117 -0
  40. braindecode/models/atcnet.py +830 -0
  41. braindecode/models/attentionbasenet.py +727 -0
  42. braindecode/models/attn_sleep.py +549 -0
  43. braindecode/models/base.py +574 -0
  44. braindecode/models/bendr.py +493 -0
  45. braindecode/models/biot.py +537 -0
  46. braindecode/models/brainmodule.py +845 -0
  47. braindecode/models/config.py +233 -0
  48. braindecode/models/contrawr.py +319 -0
  49. braindecode/models/ctnet.py +541 -0
  50. braindecode/models/deep4.py +376 -0
  51. braindecode/models/deepsleepnet.py +417 -0
  52. braindecode/models/eegconformer.py +475 -0
  53. braindecode/models/eeginception_erp.py +379 -0
  54. braindecode/models/eeginception_mi.py +379 -0
  55. braindecode/models/eegitnet.py +302 -0
  56. braindecode/models/eegminer.py +256 -0
  57. braindecode/models/eegnet.py +359 -0
  58. braindecode/models/eegnex.py +354 -0
  59. braindecode/models/eegsimpleconv.py +201 -0
  60. braindecode/models/eegsym.py +917 -0
  61. braindecode/models/eegtcnet.py +337 -0
  62. braindecode/models/fbcnet.py +225 -0
  63. braindecode/models/fblightconvnet.py +315 -0
  64. braindecode/models/fbmsnet.py +338 -0
  65. braindecode/models/hybrid.py +126 -0
  66. braindecode/models/ifnet.py +443 -0
  67. braindecode/models/labram.py +1316 -0
  68. braindecode/models/luna.py +891 -0
  69. braindecode/models/medformer.py +760 -0
  70. braindecode/models/msvtnet.py +377 -0
  71. braindecode/models/patchedtransformer.py +640 -0
  72. braindecode/models/reve.py +843 -0
  73. braindecode/models/sccnet.py +280 -0
  74. braindecode/models/shallow_fbcsp.py +212 -0
  75. braindecode/models/signal_jepa.py +1122 -0
  76. braindecode/models/sinc_shallow.py +339 -0
  77. braindecode/models/sleep_stager_blanco_2020.py +169 -0
  78. braindecode/models/sleep_stager_chambon_2018.py +159 -0
  79. braindecode/models/sparcnet.py +426 -0
  80. braindecode/models/sstdpn.py +869 -0
  81. braindecode/models/summary.csv +47 -0
  82. braindecode/models/syncnet.py +234 -0
  83. braindecode/models/tcn.py +275 -0
  84. braindecode/models/tidnet.py +397 -0
  85. braindecode/models/tsinception.py +295 -0
  86. braindecode/models/usleep.py +439 -0
  87. braindecode/models/util.py +369 -0
  88. braindecode/modules/__init__.py +92 -0
  89. braindecode/modules/activation.py +86 -0
  90. braindecode/modules/attention.py +883 -0
  91. braindecode/modules/blocks.py +160 -0
  92. braindecode/modules/convolution.py +330 -0
  93. braindecode/modules/filter.py +654 -0
  94. braindecode/modules/layers.py +216 -0
  95. braindecode/modules/linear.py +70 -0
  96. braindecode/modules/parametrization.py +38 -0
  97. braindecode/modules/stats.py +87 -0
  98. braindecode/modules/util.py +85 -0
  99. braindecode/modules/wrapper.py +90 -0
  100. braindecode/preprocessing/__init__.py +271 -0
  101. braindecode/preprocessing/eegprep_preprocess.py +1317 -0
  102. braindecode/preprocessing/mne_preprocess.py +240 -0
  103. braindecode/preprocessing/preprocess.py +579 -0
  104. braindecode/preprocessing/util.py +177 -0
  105. braindecode/preprocessing/windowers.py +1037 -0
  106. braindecode/regressor.py +234 -0
  107. braindecode/samplers/__init__.py +18 -0
  108. braindecode/samplers/base.py +399 -0
  109. braindecode/samplers/ssl.py +263 -0
  110. braindecode/training/__init__.py +23 -0
  111. braindecode/training/callbacks.py +23 -0
  112. braindecode/training/losses.py +105 -0
  113. braindecode/training/scoring.py +477 -0
  114. braindecode/util.py +419 -0
  115. braindecode/version.py +1 -0
  116. braindecode/visualization/__init__.py +8 -0
  117. braindecode/visualization/confusion_matrices.py +289 -0
  118. braindecode/visualization/gradients.py +62 -0
  119. braindecode-1.3.0.dev177069446.dist-info/METADATA +230 -0
  120. braindecode-1.3.0.dev177069446.dist-info/RECORD +124 -0
  121. braindecode-1.3.0.dev177069446.dist-info/WHEEL +5 -0
  122. braindecode-1.3.0.dev177069446.dist-info/licenses/LICENSE.txt +31 -0
  123. braindecode-1.3.0.dev177069446.dist-info/licenses/NOTICE.txt +20 -0
  124. braindecode-1.3.0.dev177069446.dist-info/top_level.txt +1 -0
braindecode/preprocessing/preprocess.py
@@ -0,0 +1,579 @@
"""Preprocessors that work on Raw or Epochs objects."""

# Authors: Hubert Banville <hubert.jbanville@gmail.com>
#          Lukas Gemein <l.gemein@gmail.com>
#          Simon Brandt <simonbrandt@protonmail.com>
#          David Sabbagh <dav.sabbagh@gmail.com>
#          Bruno Aristimunha <b.aristimunha@gmail.com>
#
# License: BSD (3-clause)

from __future__ import annotations

import platform
import sys
from collections.abc import Iterable
from functools import cached_property, partial
from importlib import import_module
from inspect import signature
from warnings import warn

if sys.version_info < (3, 9):
    from typing import Callable
else:
    from collections.abc import Callable

import numpy as np
import pandas as pd
from joblib import Parallel, delayed
from mne import BaseEpochs, create_info
from mne.io import BaseRaw
from numpy.typing import NDArray

from braindecode.datasets.base import (
    BaseConcatDataset,
    EEGWindowsDataset,
    RawDataset,
    RecordDataset,
    WindowsDataset,
)
from braindecode.datautil.serialization import (
    _check_save_dir_empty,
    load_concat_dataset,
)


class Preprocessor(object):
    """Preprocessor for an MNE Raw or Epochs object.

    Applies the provided preprocessing function to the data of a Raw or
    Epochs object.
    If the function is provided as a string, the method with that name will
    be used (e.g., 'pick_channels', 'filter', etc.).
    If it is provided as a callable and ``apply_on_array`` is True, the
    ``apply_function`` method of the Raw or Epochs object will be used to
    apply the function to the internal arrays.
    If ``apply_on_array`` is False, the callable must directly modify the Raw
    or Epochs object (e.g., by calling its method(s) or modifying its
    attributes).

    Parameters
    ----------
    fn : str or callable
        If str, the Raw/Epochs object must have a method with that name.
        If callable, directly apply the callable to the object.
    apply_on_array : bool
        Ignored if ``fn`` is not a callable. If True, the ``apply_function`` of Raw
        and Epochs will be used to run ``fn`` on the underlying arrays directly.
        If False, ``fn`` must directly modify the Raw or Epochs object.
    **kwargs : dict
        Keyword arguments forwarded to the MNE function or callable.
    """

    def __init__(self, fn: Callable | str, *, apply_on_array: bool = True, **kwargs):
        if hasattr(fn, "__name__") and fn.__name__ == "<lambda>":
            warn("Preprocessing choices with lambda functions cannot be saved.")
        if apply_on_array and not callable(fn):
            warn(
                "apply_on_array can only be True if fn is a callable function. "
                "Automatically correcting to apply_on_array=False."
            )
            apply_on_array = False
        # We store the exact input parameters. Simpler for serialization.
        self.fn = fn
        self.apply_on_array = apply_on_array
        self.kwargs = kwargs

    @property
    def _all_attrs(self):
        return ["fn", "apply_on_array", "kwargs"]

    @property
    def _init_attrs(self):
        return [k for k in self._all_attrs if k in signature(self.__init__).parameters]

    @cached_property
    def _function(self):
        kwargs = dict(self.kwargs)
        fn = self.fn
        if self.apply_on_array:
            channel_wise = kwargs.pop("channel_wise", False)
            picks = kwargs.pop("picks", None)
            n_jobs = kwargs.pop("n_jobs", 1)
            kwargs = dict(
                fun=partial(fn, **kwargs),
                channel_wise=channel_wise,
                picks=picks,
                n_jobs=n_jobs,
            )
            fn = "apply_function"

        if callable(fn):
            return partial(fn, **kwargs)
        return partial(self._apply_str, fn=fn, **kwargs)

    @staticmethod
    def _apply_str(raw_or_epochs: BaseRaw | BaseEpochs, fn: str, **kwargs):
        if not hasattr(raw_or_epochs, fn):
            raise AttributeError(f"MNE object does not have a {fn} method.")
        return getattr(raw_or_epochs, fn)(**kwargs)

    def apply(self, raw_or_epochs: BaseRaw | BaseEpochs):
        function = self._function
        try:
            result = function(raw_or_epochs)
        except RuntimeError:
            # Maybe the function needs the data to be loaded and the data was
            # not loaded yet. Not all MNE functions need data to be loaded,
            # most importantly the 'crop' function can be lazily applied
            # without preloading data, which can make the overall
            # preprocessing pipeline substantially faster.
            raw_or_epochs.load_data()
            result = function(raw_or_epochs)
        if result is not None:
            return result
        return raw_or_epochs

    def serialize(self):
        """Return a serializable representation of the Preprocessor.

        Returns
        -------
        dict
            Dictionary with the initialization parameters (e.g., 'fn' and
            'kwargs') and the class path of the Preprocessor.
        """
        out = {k: getattr(self, k) for k in self._init_attrs}
        if "fn" in out and callable(self.fn):
            out["fn"] = self.fn.__module__ + "." + self.fn.__name__
        out["__class_path__"] = (
            self.__class__.__module__ + "." + self.__class__.__name__
        )
        if "kwargs" not in out and self.kwargs:
            out["kwargs"] = self.kwargs
        return out

    @classmethod
    def deserialize(cls_parent, data: dict):
        """Create a Preprocessor from its serializable representation.

        Parameters
        ----------
        data : dict
            Dictionary with the initialization parameters and class path
            representing the Preprocessor, as returned by :meth:`serialize`.

        Returns
        -------
        Preprocessor
            The deserialized Preprocessor object.
        """
        class_path = data.pop("__class_path__")
        cls_name = class_path.split(".")[-1]
        cls_module_name = ".".join(class_path.split(".")[:-1])
        cls_module = import_module(cls_module_name)
        cls = getattr(cls_module, cls_name)

        kwargs = data.pop("kwargs") if "kwargs" in data else {}

        fn = data.get("fn", None)
        if fn is not None and "." in fn:  # callable function
            fn_name = fn.split(".")[-1]
            module_name = ".".join(fn.split(".")[:-1])
            module = import_module(module_name)
            data["fn"] = getattr(module, fn_name)

        return cls(**data, **kwargs)

    def __repr__(self):
        cls_name = self.__class__.__name__
        args_str = ", ".join(
            f"{k}={getattr(self, k).__repr__()}" for k in self._init_attrs
        )
        return f"{cls_name}({args_str})"

    def _same_attr(self, other, attr):
        a = getattr(self, attr)
        b = getattr(other, attr)
        if attr == "fn" and callable(a):
            return a.__module__ == b.__module__ and a.__name__ == b.__name__
        if isinstance(a, np.ndarray):
            return np.array_equal(a, b)
        return a == b

    def __eq__(self, other):
        if not isinstance(other, Preprocessor):
            return False
        return all(self._same_attr(other, attr) for attr in self._all_attrs) and (
            self.__class__ == other.__class__
        )
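# Usage sketch (illustrative, assuming the usual ``braindecode.preprocessing``
# exports), showing the two flavours of ``fn`` described above:
#
#     import numpy as np
#     from braindecode.preprocessing import Preprocessor
#
#     def scale(x, factor):
#         # runs on the underlying numpy array via `apply_function`
#         return np.multiply(x, factor)
#
#     preprocessors = [
#         Preprocessor("filter", l_freq=4.0, h_freq=30.0),       # str -> raw.filter(...)
#         Preprocessor(scale, factor=1e6, apply_on_array=True),  # callable on arrays
#     ]
#
# ``serialize()`` and ``deserialize()`` round-trip such objects through plain
# dicts: module-level callables are stored by dotted path, while lambdas
# cannot be saved (see the warning in ``__init__``).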
def preprocess(
    concat_ds: BaseConcatDataset,
    preprocessors: list[Preprocessor],
    save_dir: str | None = None,
    overwrite: bool = False,
    n_jobs: int | None = None,
    offset: int = 0,
    copy_data: bool | None = None,
    parallel_kwargs: dict | None = None,
):
    """Apply preprocessors to a concat dataset.

    Parameters
    ----------
    concat_ds : BaseConcatDataset
        A concat of ``RecordDataset`` to be preprocessed.
    preprocessors : list of Preprocessor
        Preprocessor objects to apply to each dataset.
    save_dir : str | None
        If provided, save preprocessed data under this directory and reload
        the datasets in ``concat_ds`` with ``preload=False``.
    overwrite : bool
        When ``save_dir`` is provided, controls whether to delete the old
        subdirectories that will be written to under ``save_dir``. If False
        and the corresponding subdirectories already exist, a
        ``FileExistsError`` is raised.
    n_jobs : int | None
        Number of jobs for parallel execution. See ``joblib.Parallel`` for
        details.
    offset : int
        Integer added to the dataset id in the concat. Useful when processing
        and saving very large datasets in chunks to preserve original
        positions.
    copy_data : bool | None
        Whether the data passed to parallel jobs should be copied or passed
        by reference.
    parallel_kwargs : dict | None
        Additional keyword arguments forwarded to ``joblib.Parallel``.
        Defaults to None (equivalent to ``{}``).
        See https://joblib.readthedocs.io/en/stable/generated/joblib.Parallel.html for details.

    Returns
    -------
    BaseConcatDataset
        Preprocessed dataset.
    """
    # In case of serialization, make sure the directory is available before
    # preprocessing
    if save_dir is not None and not overwrite:
        _check_save_dir_empty(save_dir)

    if not isinstance(preprocessors, Iterable):
        raise ValueError("preprocessors must be a list of Preprocessor objects.")
    for elem in preprocessors:
        assert hasattr(elem, "apply"), "Preprocessor object needs an `apply` method."

    parallel_processing = (n_jobs is not None) and (n_jobs != 1)

    parallel_params = {} if parallel_kwargs is None else dict(parallel_kwargs)
    parallel_params.setdefault(
        "prefer", "threads" if platform.system() == "Windows" else None
    )

    list_of_ds = Parallel(n_jobs=n_jobs, **parallel_params)(
        delayed(_preprocess)(
            ds,
            i + offset,
            preprocessors,
            save_dir,
            overwrite,
            copy_data=(
                (parallel_processing and (save_dir is None))
                if copy_data is None
                else copy_data
            ),
        )
        for i, ds in enumerate(concat_ds.datasets)
    )

    if save_dir is not None:  # Reload datasets and replace in concat_ds
        ids_to_load = [i + offset for i in range(len(concat_ds.datasets))]
        concat_ds_reloaded = load_concat_dataset(
            save_dir,
            preload=False,
            target_name=None,
            ids_to_load=ids_to_load,
        )
        _replace_inplace(concat_ds, concat_ds_reloaded)
    else:
        if parallel_processing:  # joblib made copies
            _replace_inplace(concat_ds, BaseConcatDataset(list_of_ds))
        else:
            # joblib did not make copies, so the preprocessing happened
            # in-place. Recompute cumulative sizes as transforms might have
            # changed them.
            concat_ds.cumulative_sizes = concat_ds.cumsum(concat_ds.datasets)

    return concat_ds
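# Typical call pattern, as a sketch (the dataset class and its arguments are
# illustrative only):
#
#     from braindecode.datasets import MOABBDataset
#     from braindecode.preprocessing import Preprocessor, preprocess
#
#     concat_ds = MOABBDataset(dataset_name="BNCI2014_001", subject_ids=[1])
#     preprocess(
#         concat_ds,
#         [
#             Preprocessor("pick_types", eeg=True, meg=False, stim=False),
#             Preprocessor("resample", sfreq=100),
#         ],
#         n_jobs=2,  # with n_jobs != 1, the joblib copies are written back
#     )
#
# With ``save_dir`` set, each recording is saved to disk after preprocessing
# and ``concat_ds`` is reloaded with ``preload=False`` to keep memory usage low.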
def _replace_inplace(concat_ds, new_concat_ds):
    """Replace subdatasets and preproc_kwargs of a BaseConcatDataset inplace.

    Parameters
    ----------
    concat_ds : BaseConcatDataset
        Dataset to modify inplace.
    new_concat_ds : BaseConcatDataset
        Dataset to use to modify ``concat_ds``.
    """
    if len(concat_ds.datasets) != len(new_concat_ds.datasets):
        raise ValueError("Both inputs must have the same length.")
    for i in range(len(new_concat_ds.datasets)):
        concat_ds.datasets[i] = new_concat_ds.datasets[i]

    concat_kind = "raw" if hasattr(concat_ds.datasets[0], "raw") else "window"
    preproc_kwargs_attr = concat_kind + "_preproc_kwargs"
    if hasattr(new_concat_ds, preproc_kwargs_attr):
        setattr(
            concat_ds, preproc_kwargs_attr, getattr(new_concat_ds, preproc_kwargs_attr)
        )

    # Recompute cumulative_sizes after replacing datasets
    concat_ds.cumulative_sizes = concat_ds.cumsum(concat_ds.datasets)


def _preprocess(
    ds: RecordDataset,
    ds_index,
    preprocessors,
    save_dir=None,
    overwrite=False,
    copy_data=False,
):
    """Apply preprocessor(s) to a Raw or Epochs object.

    Parameters
    ----------
    ds : RecordDataset
        Dataset object to preprocess.
    ds_index : int
        Index of the ``RecordDataset`` in its ``BaseConcatDataset``. Ignored
        if ``save_dir`` is None.
    preprocessors : list of Preprocessor
        List of preprocessors to apply to the dataset.
    save_dir : str | None
        If provided, save the preprocessed RecordDataset in the specified
        directory.
    overwrite : bool
        If True, overwrite existing file with the same name.
    copy_data : bool
        First copy the data in case it is preloaded. Necessary for parallel
        processing to work.
    """

    def _preprocess_raw_or_epochs(raw_or_epochs, preprocessors):
        # Copying the data is necessary in some scenarios for parallel
        # processing to work when the data is in memory (otherwise an error
        # about _data not being writeable is raised)
        if raw_or_epochs.preload and copy_data:
            raw_or_epochs._data = raw_or_epochs._data.copy()
        for preproc in preprocessors:
            raw_or_epochs = preproc.apply(raw_or_epochs)
        return raw_or_epochs

    if hasattr(ds, "raw"):
        if isinstance(ds, EEGWindowsDataset):
            warn(
                f"Applying preprocessors {preprocessors} to the mne.io.Raw "
                "of an EEGWindowsDataset."
            )
        processed = _preprocess_raw_or_epochs(ds.raw, preprocessors)
        if processed is not ds.raw:
            ds.raw = processed
    elif hasattr(ds, "windows"):
        processed = _preprocess_raw_or_epochs(ds.windows, preprocessors)
        if processed is not ds.windows:
            ds.windows = processed
    else:
        raise ValueError(
            "Can only preprocess concatenation of RecordDataset, "
            "with either a `raw` or `windows` attribute."
        )

    # Store preprocessing keyword arguments in the dataset
    _set_preproc_kwargs(ds, preprocessors)

    if save_dir is not None:
        concat_ds = BaseConcatDataset([ds])
        concat_ds.save(save_dir, overwrite=overwrite, offset=ds_index)
    else:
        return ds


def _set_preproc_kwargs(ds, preprocessors):
    """Record preprocessing keyword arguments in a RecordDataset.

    Parameters
    ----------
    ds : RecordDataset
        Dataset in which to record preprocessing keyword arguments.
    preprocessors : list
        List of preprocessors.
    """
    preproc_kwargs = [p.serialize() for p in preprocessors]
    if isinstance(ds, WindowsDataset):
        kind = "window"
    elif isinstance(ds, EEGWindowsDataset):
        kind = "raw"
    elif isinstance(ds, RawDataset):
        kind = "raw"
    else:
        raise TypeError(f"ds must be a RecordDataset, got {type(ds)}")
    old_preproc_kwargs = getattr(ds, kind + "_preproc_kwargs")
    old_preproc_kwargs.extend(preproc_kwargs)


def exponential_moving_standardize(
    data: NDArray,
    factor_new: float = 0.001,
    init_block_size: int | None = None,
    eps: float = 1e-4,
):
    r"""Perform exponential moving standardization.

    Compute the exponential moving mean :math:`m_t` at time `t` as
    :math:`m_t=\mathrm{factornew} \cdot mean(x_t) + (1 - \mathrm{factornew}) \cdot m_{t-1}`.

    Then, compute the exponential moving variance :math:`v_t` at time `t` as
    :math:`v_t=\mathrm{factornew} \cdot (m_t - x_t)^2 + (1 - \mathrm{factornew}) \cdot v_{t-1}`.

    Finally, standardize the data point :math:`x_t` at time `t` as:
    :math:`x'_t=(x_t - m_t) / max(\sqrt{v_t}, eps)`.

    Parameters
    ----------
    data : np.ndarray (n_channels, n_times)
        Data to standardize.
    factor_new : float
        Weight of the current sample in the moving statistics (used as
        ``alpha`` of the exponentially weighted moving mean and variance).
    init_block_size : int
        Standardize data before this index with regular standardization.
    eps : float
        Stabilizer for division by zero variance.

    Returns
    -------
    standardized : np.ndarray (n_channels, n_times)
        Standardized data.
    """
    data = data.T
    df = pd.DataFrame(data)
    meaned = df.ewm(alpha=factor_new).mean()
    demeaned = df - meaned
    squared = demeaned * demeaned
    square_ewmed = squared.ewm(alpha=factor_new).mean()
    standardized = demeaned / np.maximum(eps, np.sqrt(np.array(square_ewmed)))
    standardized = np.array(standardized)
    if init_block_size is not None:
        i_time_axis = 0
        init_mean = np.mean(data[0:init_block_size], axis=i_time_axis, keepdims=True)
        init_std = np.std(data[0:init_block_size], axis=i_time_axis, keepdims=True)
        init_block_standardized = (data[0:init_block_size] - init_mean) / np.maximum(
            eps, init_std
        )
        standardized[0:init_block_size] = init_block_standardized
    return standardized.T
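# Shape-level sketch of the function above (values are arbitrary):
#
#     >>> import numpy as np
#     >>> x = np.random.randn(2, 1000)  # (n_channels, n_times)
#     >>> exponential_moving_standardize(x, factor_new=1e-3, init_block_size=200).shape
#     (2, 1000)
#
# In practice it is usually wrapped as
# ``Preprocessor(exponential_moving_standardize, factor_new=1e-3,
# init_block_size=1000, apply_on_array=True)`` so it runs on each recording's
# underlying array.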
def exponential_moving_demean(
    data: NDArray, factor_new: float = 0.001, init_block_size: int | None = None
):
    r"""Perform exponential moving demeaning.

    Compute the exponential moving mean :math:`m_t` at time `t` as
    :math:`m_t=\mathrm{factornew} \cdot mean(x_t) + (1 - \mathrm{factornew}) \cdot m_{t-1}`.

    Demean the data point :math:`x_t` at time `t` as:
    :math:`x'_t=(x_t - m_t)`.

    Parameters
    ----------
    data : np.ndarray (n_channels, n_times)
        Data to demean.
    factor_new : float
        Weight of the current sample in the moving mean (used as ``alpha``
        of the exponentially weighted moving mean).
    init_block_size : int
        Demean data before this index with regular demeaning.

    Returns
    -------
    demeaned : np.ndarray (n_channels, n_times)
        Demeaned data.
    """
    data = data.T
    df = pd.DataFrame(data)
    meaned = df.ewm(alpha=factor_new).mean()
    demeaned = df - meaned
    demeaned = np.array(demeaned)
    if init_block_size is not None:
        i_time_axis = 0
        init_mean = np.mean(data[0:init_block_size], axis=i_time_axis, keepdims=True)
        demeaned[0:init_block_size] = data[0:init_block_size] - init_mean
    return demeaned.T
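# Same shape conventions as ``exponential_moving_standardize`` above:
#
#     >>> import numpy as np
#     >>> x = np.random.randn(2, 1000)  # (n_channels, n_times)
#     >>> exponential_moving_demean(x, factor_new=1e-3, init_block_size=200).shape
#     (2, 1000)
#
# Only the exponential moving mean is subtracted; no variance scaling is done.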
def filterbank(
    raw: BaseRaw,
    frequency_bands: list[tuple[float, float]],
    drop_original_signals: bool = True,
    order_by_frequency_band: bool = False,
    **mne_filter_kwargs,
):
    """Apply multiple bandpass filters to the signals in ``raw``.

    The ``raw`` object is modified in-place and its number of channels grows
    to ``len(frequency_bands) * len(raw.ch_names)`` (minus
    ``len(raw.ch_names)`` if ``drop_original_signals`` is True).

    Parameters
    ----------
    raw : mne.io.Raw
        The raw signals to be filtered.
    frequency_bands : list of tuple
        The frequency bands to be filtered for (e.g. ``[(4, 8), (8, 13)]``).
    drop_original_signals : bool
        Whether to drop the original, unfiltered signals.
    order_by_frequency_band : bool
        If True, the returned channels are ordered by frequency band: with
        channels Cz, O1 and filterbank ranges ``[(4, 8), (8, 13)]``, the
        returned channels are [Cz_4-8, O1_4-8, Cz_8-13, O1_8-13]. If False,
        the order is [Cz_4-8, Cz_8-13, O1_4-8, O1_8-13].
    mne_filter_kwargs : dict
        Keyword arguments for filtering supported by ``mne.io.Raw.filter()``.
        Please refer to MNE for a detailed explanation.
    """
    if not frequency_bands:
        raise ValueError(f"Expected at least one frequency band, got {frequency_bands}")
    if not all([len(ch_name) < 8 for ch_name in raw.ch_names]):
        warn(
            "Try to use shorter channel names, since frequency band "
            "annotation requires an estimated 4-8 chars depending on the "
            "frequency ranges. Will truncate to 15 chars (mne max)."
        )
    original_ch_names = raw.ch_names
    all_filtered = []
    for l_freq, h_freq in frequency_bands:
        filtered = raw.copy()
        filtered.filter(l_freq=l_freq, h_freq=h_freq, **mne_filter_kwargs)
        # mne automatically changes the highpass/lowpass info values
        # when applying filters, and channels can't be added if they have
        # different such parameters. Not needed when making picks, as the
        # high pass is not modified by filter if a pick is specified.

        ch_names = filtered.info.ch_names
        ch_types = filtered.info.get_channel_types()
        sampling_freq = filtered.info["sfreq"]

        info = create_info(ch_names=ch_names, ch_types=ch_types, sfreq=sampling_freq)

        filtered.info = info

        # add frequency band annotation to channel names
        # truncate to a max of 15 characters, since mne does not allow for more
        filtered.rename_channels(
            {
                old_name: (old_name + f"_{l_freq}-{h_freq}")[-15:]
                for old_name in filtered.ch_names
            }
        )
        all_filtered.append(filtered)
    raw.add_channels(all_filtered)
    if not order_by_frequency_band:
        # order channels by name and not by frequency band:
        # index the list with a stepsize of the number of channels for each of
        # the original channels
        chs_by_freq_band = []
        for i in range(len(original_ch_names)):
            chs_by_freq_band.extend(raw.ch_names[i :: len(original_ch_names)])
        raw.reorder_channels(chs_by_freq_band)
    if drop_original_signals:
        raw.drop_channels(original_ch_names)
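# Sketch of applying the filterbank above through ``preprocess`` (band edges
# are illustrative). ``apply_on_array=False`` is required because
# ``filterbank`` operates on the ``mne.io.Raw`` object itself, not its array:
#
#     from braindecode.preprocessing import Preprocessor, preprocess
#
#     preprocess(
#         concat_ds,
#         [
#             Preprocessor(
#                 filterbank,
#                 frequency_bands=[(4, 8), (8, 13), (13, 30)],
#                 drop_original_signals=True,
#                 apply_on_array=False,
#             )
#         ],
#     )
#
# Each original channel then appears once per band, e.g. Cz becomes
# ``Cz_4-8``, ``Cz_8-13`` and ``Cz_13-30``.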