braindecode: braindecode-0.8.1-py3-none-any.whl → braindecode-1.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of braindecode might be problematic. See the package's registry page for more details.

Files changed (108)
  1. braindecode/__init__.py +1 -2
  2. braindecode/augmentation/__init__.py +39 -19
  3. braindecode/augmentation/base.py +25 -28
  4. braindecode/augmentation/functional.py +237 -100
  5. braindecode/augmentation/transforms.py +325 -158
  6. braindecode/classifier.py +26 -24
  7. braindecode/datasets/__init__.py +28 -10
  8. braindecode/datasets/base.py +220 -134
  9. braindecode/datasets/bbci.py +43 -52
  10. braindecode/datasets/bcicomp.py +47 -32
  11. braindecode/datasets/bids.py +245 -0
  12. braindecode/datasets/mne.py +45 -24
  13. braindecode/datasets/moabb.py +87 -27
  14. braindecode/datasets/nmt.py +311 -0
  15. braindecode/datasets/sleep_physio_challe_18.py +412 -0
  16. braindecode/datasets/sleep_physionet.py +43 -26
  17. braindecode/datasets/tuh.py +324 -140
  18. braindecode/datasets/xy.py +27 -12
  19. braindecode/datautil/__init__.py +37 -18
  20. braindecode/datautil/serialization.py +110 -72
  21. braindecode/eegneuralnet.py +63 -47
  22. braindecode/functional/__init__.py +22 -0
  23. braindecode/functional/functions.py +250 -0
  24. braindecode/functional/initialization.py +47 -0
  25. braindecode/models/__init__.py +84 -14
  26. braindecode/models/atcnet.py +193 -164
  27. braindecode/models/attentionbasenet.py +599 -0
  28. braindecode/models/base.py +86 -102
  29. braindecode/models/biot.py +504 -0
  30. braindecode/models/contrawr.py +317 -0
  31. braindecode/models/ctnet.py +536 -0
  32. braindecode/models/deep4.py +116 -77
  33. braindecode/models/deepsleepnet.py +149 -119
  34. braindecode/models/eegconformer.py +112 -173
  35. braindecode/models/eeginception_erp.py +109 -118
  36. braindecode/models/eeginception_mi.py +161 -97
  37. braindecode/models/eegitnet.py +215 -152
  38. braindecode/models/eegminer.py +254 -0
  39. braindecode/models/eegnet.py +228 -161
  40. braindecode/models/eegnex.py +247 -0
  41. braindecode/models/eegresnet.py +234 -152
  42. braindecode/models/eegsimpleconv.py +199 -0
  43. braindecode/models/eegtcnet.py +335 -0
  44. braindecode/models/fbcnet.py +221 -0
  45. braindecode/models/fblightconvnet.py +313 -0
  46. braindecode/models/fbmsnet.py +324 -0
  47. braindecode/models/hybrid.py +52 -71
  48. braindecode/models/ifnet.py +441 -0
  49. braindecode/models/labram.py +1186 -0
  50. braindecode/models/msvtnet.py +375 -0
  51. braindecode/models/sccnet.py +207 -0
  52. braindecode/models/shallow_fbcsp.py +50 -56
  53. braindecode/models/signal_jepa.py +1011 -0
  54. braindecode/models/sinc_shallow.py +337 -0
  55. braindecode/models/sleep_stager_blanco_2020.py +55 -46
  56. braindecode/models/sleep_stager_chambon_2018.py +54 -53
  57. braindecode/models/sleep_stager_eldele_2021.py +247 -141
  58. braindecode/models/sparcnet.py +424 -0
  59. braindecode/models/summary.csv +41 -0
  60. braindecode/models/syncnet.py +232 -0
  61. braindecode/models/tcn.py +158 -88
  62. braindecode/models/tidnet.py +280 -167
  63. braindecode/models/tsinception.py +283 -0
  64. braindecode/models/usleep.py +190 -177
  65. braindecode/models/util.py +109 -145
  66. braindecode/modules/__init__.py +84 -0
  67. braindecode/modules/activation.py +60 -0
  68. braindecode/modules/attention.py +757 -0
  69. braindecode/modules/blocks.py +108 -0
  70. braindecode/modules/convolution.py +274 -0
  71. braindecode/modules/filter.py +628 -0
  72. braindecode/modules/layers.py +131 -0
  73. braindecode/modules/linear.py +49 -0
  74. braindecode/modules/parametrization.py +38 -0
  75. braindecode/modules/stats.py +77 -0
  76. braindecode/modules/util.py +76 -0
  77. braindecode/modules/wrapper.py +73 -0
  78. braindecode/preprocessing/__init__.py +36 -11
  79. braindecode/preprocessing/mne_preprocess.py +13 -7
  80. braindecode/preprocessing/preprocess.py +139 -75
  81. braindecode/preprocessing/windowers.py +576 -187
  82. braindecode/regressor.py +23 -12
  83. braindecode/samplers/__init__.py +16 -8
  84. braindecode/samplers/base.py +146 -32
  85. braindecode/samplers/ssl.py +162 -17
  86. braindecode/training/__init__.py +18 -10
  87. braindecode/training/callbacks.py +2 -4
  88. braindecode/training/losses.py +3 -8
  89. braindecode/training/scoring.py +76 -68
  90. braindecode/util.py +55 -59
  91. braindecode/version.py +1 -1
  92. braindecode/visualization/__init__.py +2 -3
  93. braindecode/visualization/confusion_matrices.py +117 -73
  94. braindecode/visualization/gradients.py +14 -10
  95. {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/METADATA +42 -58
  96. braindecode-1.1.0.dist-info/RECORD +101 -0
  97. {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/WHEEL +1 -1
  98. {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info/licenses}/LICENSE.txt +1 -1
  99. braindecode-1.1.0.dist-info/licenses/NOTICE.txt +20 -0
  100. braindecode/datautil/mne.py +0 -9
  101. braindecode/datautil/preprocess.py +0 -12
  102. braindecode/datautil/windowers.py +0 -6
  103. braindecode/datautil/xy.py +0 -9
  104. braindecode/models/eeginception.py +0 -317
  105. braindecode/models/functions.py +0 -47
  106. braindecode/models/modules.py +0 -358
  107. braindecode-0.8.1.dist-info/RECORD +0 -68
  108. {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,131 @@
1
+ # Authors: Robin Schirrmeister <robintibor@gmail.com>
2
+ #
3
+ # License: BSD (3-clause)
4
+ from __future__ import annotations
5
+
6
+ import torch
7
+ from einops.layers.torch import Rearrange
8
+ from torch import nn
9
+
10
+ from braindecode.functional import drop_path
11
+
12
+
13
+ class Ensure4d(nn.Module):
14
+ def forward(self, x):
15
+ while len(x.shape) < 4:
16
+ x = x.unsqueeze(-1)
17
+ return x
18
+
19
+
20
+ class Chomp1d(nn.Module):
21
+ def __init__(self, chomp_size):
22
+ super().__init__()
23
+ self.chomp_size = chomp_size
24
+
25
+ def extra_repr(self):
26
+ return "chomp_size={}".format(self.chomp_size)
27
+
28
+ def forward(self, x):
29
+ return x[:, :, : -self.chomp_size].contiguous()
30
+
31
+
32
+ class TimeDistributed(nn.Module):
33
+ """Apply module on multiple windows.
34
+
35
+ Apply the provided module on a sequence of windows and return their
36
+ concatenation.
37
+ Useful with sequence-to-prediction models (e.g. sleep stager which must map
38
+ a sequence of consecutive windows to the label of the middle window in the
39
+ sequence).
40
+
41
+ Parameters
42
+ ----------
43
+ module : nn.Module
44
+ Module to be applied to the input windows. Must accept an input of
45
+ shape (batch_size, n_channels, n_times).
46
+ """
47
+
48
+ def __init__(self, module):
49
+ super().__init__()
50
+ self.module = module
51
+
52
+ def forward(self, x):
53
+ """
54
+ Parameters
55
+ ----------
56
+ x : torch.Tensor
57
+ Sequence of windows, of shape (batch_size, seq_len, n_channels,
58
+ n_times).
59
+
60
+ Returns
61
+ -------
62
+ torch.Tensor
63
+ Shape (batch_size, seq_len, output_size).
64
+ """
65
+ b, s, c, t = x.shape
66
+ out = self.module(x.view(b * s, c, t))
67
+ return out.view(b, s, -1)
68
+
69
+
70
+ class DropPath(nn.Module):
71
+ """Drop paths, also known as Stochastic Depth, per sample.
72
+
73
+ When applied in main path of residual blocks.
74
+
75
+ Parameters:
76
+ -----------
77
+ drop_prob: float (default=None)
78
+ Drop path probability (should be in range 0-1).
79
+
80
+ Notes
81
+ -----
82
+ Code copied and modified from VISSL facebookresearch:
83
+ https://github.com/facebookresearch/vissl/blob/0b5d6a94437bc00baed112ca90c9d78c6ccfbafb/vissl/models/model_helpers.py#L676
84
+ All rights reserved.
85
+
86
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
87
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
88
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
89
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
90
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
91
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
92
+ SOFTWARE.
93
+ """
94
+
95
+ def __init__(self, drop_prob=None):
96
+ super(DropPath, self).__init__()
97
+ self.drop_prob = drop_prob
98
+
99
+ def forward(self, x):
100
+ return drop_path(x, self.drop_prob, self.training)
101
+
102
+ # Utility function to print DropPath module
103
+ def extra_repr(self) -> str:
104
+ return f"p={self.drop_prob}"
105
+
106
+
107
+ class SqueezeFinalOutput(nn.Module):
108
+ """
109
+
110
+ Removes empty dimension at end and potentially removes empty time
111
+ dimension. It does not just use squeeze as we never want to remove
112
+ first dimension.
113
+
114
+ Returns
115
+ -------
116
+ x: torch.Tensor
117
+ squeezed tensor
118
+ """
119
+
120
+ def __init__(self):
121
+ super().__init__()
122
+
123
+ self.squeeze = Rearrange("b c t 1 -> b c t")
124
+
125
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
126
+ # 1) drop feature dim
127
+ x = self.squeeze(x)
128
+ # 2) drop time dim if singleton
129
+ if x.shape[-1] == 1:
130
+ x = x.squeeze(-1)
131
+ return x
@@ -0,0 +1,49 @@
1
+ from torch import nn
2
+ from torch.nn.utils.parametrize import register_parametrization
3
+
4
+ from braindecode.modules.parametrization import MaxNorm, MaxNormParametrize
5
+
6
+
7
+ class MaxNormLinear(nn.Linear):
8
+ """Linear layer with MaxNorm constraining on weights.
9
+
10
+ Equivalent of Keras tf.keras.Dense(..., kernel_constraint=max_norm())
11
+ [1]_ and [2]_. Implemented as advised in [3]_.
12
+
13
+ Parameters
14
+ ----------
15
+ in_features: int
16
+ Size of each input sample.
17
+ out_features: int
18
+ Size of each output sample.
19
+ bias: bool, optional
20
+ If set to ``False``, the layer will not learn an additive bias.
21
+ Default: ``True``.
22
+
23
+ References
24
+ ----------
25
+ .. [1] https://keras.io/api/layers/core_layers/dense/#dense-class
26
+ .. [2] https://www.tensorflow.org/api_docs/python/tf/keras/constraints/
27
+ MaxNorm
28
+ .. [3] https://discuss.pytorch.org/t/how-to-correctly-implement-in-place-
29
+ max-norm-constraint/96769
30
+ """
31
+
32
+ def __init__(
33
+ self, in_features, out_features, bias=True, max_norm_val=2, eps=1e-5, **kwargs
34
+ ):
35
+ super().__init__(
36
+ in_features=in_features, out_features=out_features, bias=bias, **kwargs
37
+ )
38
+ self._max_norm_val = max_norm_val
39
+ self._eps = eps
40
+ register_parametrization(self, "weight", MaxNorm(self._max_norm_val, self._eps))
41
+
42
+
43
+ class LinearWithConstraint(nn.Linear):
44
+ """Linear layer with max-norm constraint on the weights."""
45
+
46
+ def __init__(self, *args, max_norm=1.0, **kwargs):
47
+ super(LinearWithConstraint, self).__init__(*args, **kwargs)
48
+ self.max_norm = max_norm
49
+ register_parametrization(self, "weight", MaxNormParametrize(self.max_norm))
@@ -0,0 +1,38 @@
1
+ import torch
2
+ from torch import nn
3
+
4
+
5
+ class MaxNorm(nn.Module):
6
+ def __init__(self, max_norm_val=2.0, eps=1e-5):
7
+ super().__init__()
8
+ self.max_norm_val = max_norm_val
9
+ self.eps = eps
10
+
11
+ def forward(self, X: torch.Tensor) -> torch.Tensor:
12
+ norm = X.norm(2, dim=0, keepdim=True)
13
+ denom = norm.clamp(min=self.max_norm_val / 2)
14
+ number = denom.clamp(max=self.max_norm_val)
15
+ return X * (number / (denom + self.eps))
16
+
17
+ def right_inverse(self, X: torch.Tensor) -> torch.Tensor:
18
+ # Assuming the forward scales X by a factor s,
19
+ # the right inverse would scale it back by 1/s.
20
+ norm = X.norm(2, dim=0, keepdim=True)
21
+ denom = norm.clamp(min=self.max_norm_val / 2)
22
+ number = denom.clamp(max=self.max_norm_val)
23
+ scale = number / (denom + self.eps)
24
+ return X / scale
25
+
26
+
27
+ class MaxNormParametrize(nn.Module):
28
+ """
29
+ Enforce a max‑norm constraint on the rows of a weight tensor via parametrization.
30
+ """
31
+
32
+ def __init__(self, max_norm: float = 1.0):
33
+ super().__init__()
34
+ self.max_norm = max_norm
35
+
36
+ def forward(self, X: torch.Tensor) -> torch.Tensor:
37
+ # Renormalize each "row" (dim=0 slice) to have at most self.max_norm L2-norm
38
+ return X.renorm(p=2, dim=0, maxnorm=self.max_norm)
@@ -0,0 +1,77 @@
1
+ from __future__ import annotations
2
+
3
+ from functools import partial
4
+ from typing import Callable, Optional
5
+
6
+ import torch
7
+ from torch import nn
8
+
9
+
10
+ class StatLayer(nn.Module):
11
+ """
12
+ Generic layer to compute a statistical function along a specified dimension.
13
+ Parameters
14
+ ----------
15
+ stat_fn : Callable
16
+ A function like torch.mean, torch.std, etc.
17
+ dim : int
18
+ Dimension along which to apply the function.
19
+ keepdim : bool, default=True
20
+ Whether to keep the reduced dimension.
21
+ clamp_range : tuple(float, float), optional
22
+ Used only for functions requiring clamping (e.g., log variance).
23
+ apply_log : bool, default=False
24
+ Whether to apply log after computation (used for LogVarLayer).
25
+ """
26
+
27
+ def __init__(
28
+ self,
29
+ stat_fn: Callable[..., torch.Tensor],
30
+ dim: int,
31
+ keepdim: bool = True,
32
+ clamp_range: Optional[tuple[float, float]] = None,
33
+ apply_log: bool = False,
34
+ ) -> None:
35
+ super().__init__()
36
+ self.stat_fn = stat_fn
37
+ self.dim = dim
38
+ self.keepdim = keepdim
39
+ self.clamp_range = clamp_range
40
+ self.apply_log = apply_log
41
+
42
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
43
+ out = self.stat_fn(x, dim=self.dim, keepdim=self.keepdim)
44
+ if self.clamp_range is not None:
45
+ out = torch.clamp(out, min=self.clamp_range[0], max=self.clamp_range[1])
46
+ if self.apply_log:
47
+ out = torch.log(out)
48
+ return out
49
+
50
+
51
+ # make things more simple
52
+ def _max_fn(x: torch.Tensor, dim: int, keepdim: bool) -> torch.Tensor:
53
+ return x.max(dim=dim, keepdim=keepdim)[0]
54
+
55
+
56
+ def _power_fn(x: torch.Tensor, dim: int, keepdim: bool) -> torch.Tensor:
57
+ # compute mean of squared values along `dim`
58
+ return torch.mean(x**2, dim=dim, keepdim=keepdim)
59
+
60
+
61
+ MeanLayer: Callable[[int, bool], StatLayer] = partial(StatLayer, torch.mean)
62
+ MaxLayer: Callable[[int, bool], StatLayer] = partial(StatLayer, _max_fn)
63
+ VarLayer: Callable[[int, bool], StatLayer] = partial(StatLayer, torch.var)
64
+ StdLayer: Callable[[int, bool], StatLayer] = partial(StatLayer, torch.std)
65
+ LogVarLayer: Callable[[int, bool], StatLayer] = partial(
66
+ StatLayer,
67
+ torch.var,
68
+ clamp_range=(1e-6, 1e6),
69
+ apply_log=True,
70
+ )
71
+
72
+ LogPowerLayer: Callable[[int, bool], StatLayer] = partial(
73
+ StatLayer,
74
+ _power_fn,
75
+ clamp_range=(1e-4, 1e4),
76
+ apply_log=True,
77
+ )
@@ -0,0 +1,76 @@
1
+ # Authors: Robin Schirrmeister <robintibor@gmail.com>
2
+ # Hubert Banville <hubert.jbanville@gmail.com>
3
+ #
4
+ # License: BSD (3-clause)
5
+
6
+ import numpy as np
7
+ from scipy.special import log_softmax
8
+
9
+
10
+ def _pad_shift_array(x, stride=1):
11
+ """Zero-pad and shift rows of a 3D array.
12
+
13
+ E.g., used to align predictions of corresponding windows in
14
+ sequence-to-sequence models.
15
+
16
+ Parameters
17
+ ----------
18
+ x : np.ndarray
19
+ Array of shape (n_rows, n_classes, n_windows).
20
+ stride : int
21
+ Number of non-overlapping elements between two consecutive sequences.
22
+
23
+ Returns
24
+ -------
25
+ np.ndarray :
26
+ Array of shape (n_rows, n_classes, (n_rows - 1) * stride + n_windows)
27
+ where each row is obtained by zero-padding the corresponding row in
28
+ ``x`` before and after in the last dimension.
29
+ """
30
+ if x.ndim != 3:
31
+ raise NotImplementedError(
32
+ f"x must be of shape (n_rows, n_classes, n_windows), got {x.shape}"
33
+ )
34
+ x_padded = np.pad(x, ((0, 0), (0, 0), (0, (x.shape[0] - 1) * stride)))
35
+ orig_strides = x_padded.strides
36
+ new_strides = (
37
+ orig_strides[0] - stride * orig_strides[2],
38
+ orig_strides[1],
39
+ orig_strides[2],
40
+ )
41
+ return np.lib.stride_tricks.as_strided(x_padded, strides=new_strides)
42
+
43
+
44
+ def aggregate_probas(logits, n_windows_stride=1):
45
+ """Aggregate predicted probabilities with self-ensembling.
46
+
47
+ Aggregate window-wise predicted probabilities obtained on overlapping
48
+ sequences of windows using multiplicative voting as described in
49
+ [Phan2018]_.
50
+
51
+ Parameters
52
+ ----------
53
+ logits : np.ndarray
54
+ Array of shape (n_sequences, n_classes, n_windows) containing the
55
+ logits (i.e. the raw unnormalized scores for each class) for each
56
+ window of each sequence.
57
+ n_windows_stride : int
58
+ Number of windows between two consecutive sequences. Default is 1
59
+ (maximally overlapping sequences).
60
+
61
+ Returns
62
+ -------
63
+ np.ndarray :
64
+ Array of shape ((n_rows - 1) * stride + n_windows, n_classes)
65
+ containing the aggregated predicted probabilities for each window
66
+ contained in the input sequences.
67
+
68
+ References
69
+ ----------
70
+ .. [Phan2018] Phan, H., Andreotti, F., Cooray, N., Chén, O. Y., &
71
+ De Vos, M. (2018). Joint classification and prediction CNN framework
72
+ for automatic sleep stage classification. IEEE Transactions on
73
+ Biomedical Engineering, 66(5), 1285-1296.
74
+ """
75
+ log_probas = log_softmax(logits, axis=1)
76
+ return _pad_shift_array(log_probas, stride=n_windows_stride).sum(axis=0).T
@@ -0,0 +1,73 @@
1
+ import torch
2
+ from torch import nn
3
+
4
+
5
+ class Expression(nn.Module):
6
+ """Compute given expression on forward pass.
7
+
8
+ Parameters
9
+ ----------
10
+ expression_fn : callable
11
+ Should accept variable number of objects of type
12
+ `torch.autograd.Variable` to compute its output.
13
+ """
14
+
15
+ def __init__(self, expression_fn):
16
+ super().__init__()
17
+ self.expression_fn = expression_fn
18
+
19
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
20
+ return self.expression_fn(x)
21
+
22
+ def __repr__(self):
23
+ if hasattr(self.expression_fn, "func") and hasattr(
24
+ self.expression_fn, "kwargs"
25
+ ):
26
+ expression_str = "{:s} {:s}".format(
27
+ self.expression_fn.func.__name__, str(self.expression_fn.kwargs)
28
+ )
29
+ elif hasattr(self.expression_fn, "__name__"):
30
+ expression_str = self.expression_fn.__name__
31
+ else:
32
+ expression_str = repr(self.expression_fn)
33
+ return self.__class__.__name__ + "(expression=%s) " % expression_str
34
+
35
+
36
+ class IntermediateOutputWrapper(nn.Module):
37
+ """Wraps network model such that outputs of intermediate layers can be returned.
38
+ forward() returns list of intermediate activations in a network during forward pass.
39
+
40
+ Parameters
41
+ ----------
42
+ to_select : list
43
+ list of module names for which activation should be returned
44
+ model : model object
45
+ network model
46
+
47
+ Examples
48
+ --------
49
+ >>> model = Deep4Net()
50
+ >>> select_modules = ['conv_spat','conv_2','conv_3','conv_4'] # Specify intermediate outputs
51
+ >>> model_pert = IntermediateOutputWrapper(select_modules,model) # Wrap model
52
+ """
53
+
54
+ def __init__(self, to_select, model):
55
+ if not len(list(model.children())) == len(list(model.named_children())):
56
+ raise Exception("All modules in model need to have names!")
57
+
58
+ super().__init__()
59
+
60
+ modules_list = model.named_children()
61
+ for key, module in modules_list:
62
+ self.add_module(key, module)
63
+ self._modules[key].load_state_dict(module.state_dict())
64
+ self._to_select = to_select
65
+
66
+ def forward(self, x):
67
+ # Call modules individually and append activation to output if module is in to_select
68
+ o = []
69
+ for name, module in self._modules.items():
70
+ x = module(x)
71
+ if name in self._to_select:
72
+ o.append(x)
73
+ return o
@@ -1,12 +1,37 @@
1
- from .preprocess import (exponential_moving_demean,
2
- exponential_moving_standardize, filterbank,
3
- preprocess, Preprocessor)
4
- from .mne_preprocess import (Resample, DropChannels, SetEEGReference, Filter, Pick, Crop)
5
- from .windowers import (create_windows_from_events, create_fixed_length_windows,
6
- create_windows_from_target_channels)
1
+ from .mne_preprocess import ( # type: ignore[attr-defined]
2
+ Crop,
3
+ DropChannels,
4
+ Filter,
5
+ Pick,
6
+ Resample,
7
+ SetEEGReference,
8
+ )
9
+ from .preprocess import (
10
+ Preprocessor,
11
+ exponential_moving_demean,
12
+ exponential_moving_standardize,
13
+ filterbank,
14
+ preprocess,
15
+ )
16
+ from .windowers import (
17
+ create_fixed_length_windows,
18
+ create_windows_from_events,
19
+ create_windows_from_target_channels,
20
+ )
7
21
 
8
- __all__ = ["exponential_moving_demean", "exponential_moving_standardize",
9
- "filterbank", "preprocess", "Preprocessor", "Resample", "DropChannels",
10
- "SetEEGReference", "Filter", "Pick", "Crop",
11
- "create_windows_from_events", "create_fixed_length_windows",
12
- "create_windows_from_target_channels"]
22
+ __all__ = [
23
+ "exponential_moving_demean",
24
+ "exponential_moving_standardize",
25
+ "filterbank",
26
+ "preprocess",
27
+ "Preprocessor",
28
+ "Resample",
29
+ "DropChannels",
30
+ "SetEEGReference",
31
+ "Filter",
32
+ "Pick",
33
+ "Crop",
34
+ "create_windows_from_events",
35
+ "create_fixed_length_windows",
36
+ "create_windows_from_target_channels",
37
+ ]
@@ -1,11 +1,14 @@
1
1
  """Preprocessor objects based on mne methods."""
2
+
2
3
  # Authors: Bruna Lopes <brunajaflopes@gmail.com>
3
4
  # Bruno Aristimunha <b.aristimunha@gmail.com>
4
5
  #
5
6
  # License: BSD-3
6
7
  import inspect
8
+
7
9
  import mne.io
8
- from braindecode.preprocessing import Preprocessor
10
+
11
+ from braindecode.preprocessing.preprocess import Preprocessor
9
12
  from braindecode.util import _update_moabb_docstring
10
13
 
11
14
 
@@ -31,9 +34,9 @@ def _generate_mne_pre_processor(function):
31
34
  """
32
35
  Generate a class based on an MNE function for preprocessing.
33
36
  """
34
- class_name = ''.join(
35
- word.title() for word in function.__name__.split('_')).replace('Eeg',
36
- 'EEG')
37
+ class_name = "".join(word.title() for word in function.__name__.split("_")).replace(
38
+ "Eeg", "EEG"
39
+ )
37
40
  import_path = f"{function.__module__}.{function.__name__}"
38
41
  doc = f" See more details in {import_path}"
39
42
 
@@ -55,7 +58,7 @@ mne_functions = [
55
58
  mne.io.Raw.filter,
56
59
  mne.io.Raw.crop,
57
60
  mne.io.Raw.pick,
58
- mne.io.Raw.set_eeg_reference
61
+ mne.io.Raw.set_eeg_reference,
59
62
  ]
60
63
 
61
64
  # Automatically generate and add classes to the global namespace
@@ -64,8 +67,11 @@ for function in mne_functions:
64
67
  globals()[class_obj.__name__] = class_obj
65
68
 
66
69
  # Define __all__ based on the generated class names
67
- __all__ = [class_obj.__name__ for class_obj in globals().values() if
68
- isinstance(class_obj, type)]
70
+ __all__ = [
71
+ class_obj.__name__
72
+ for class_obj in globals().values()
73
+ if isinstance(class_obj, type)
74
+ ]
69
75
 
70
76
  # Clean up unnecessary variables
71
77
  del mne_functions, function, class_obj