braindecode 0.8.1__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of braindecode might be problematic. Click here for more details.

Files changed (108)
  1. braindecode/__init__.py +1 -2
  2. braindecode/augmentation/__init__.py +39 -19
  3. braindecode/augmentation/base.py +25 -28
  4. braindecode/augmentation/functional.py +237 -100
  5. braindecode/augmentation/transforms.py +326 -158
  6. braindecode/classifier.py +26 -24
  7. braindecode/datasets/__init__.py +28 -10
  8. braindecode/datasets/base.py +220 -134
  9. braindecode/datasets/bbci.py +43 -52
  10. braindecode/datasets/bcicomp.py +47 -32
  11. braindecode/datasets/bids.py +245 -0
  12. braindecode/datasets/mne.py +45 -24
  13. braindecode/datasets/moabb.py +87 -27
  14. braindecode/datasets/nmt.py +311 -0
  15. braindecode/datasets/sleep_physio_challe_18.py +412 -0
  16. braindecode/datasets/sleep_physionet.py +43 -26
  17. braindecode/datasets/tuh.py +324 -140
  18. braindecode/datasets/xy.py +27 -12
  19. braindecode/datautil/__init__.py +34 -18
  20. braindecode/datautil/serialization.py +98 -71
  21. braindecode/eegneuralnet.py +63 -47
  22. braindecode/functional/__init__.py +10 -0
  23. braindecode/functional/functions.py +251 -0
  24. braindecode/functional/initialization.py +47 -0
  25. braindecode/models/__init__.py +36 -14
  26. braindecode/models/atcnet.py +153 -159
  27. braindecode/models/attentionbasenet.py +550 -0
  28. braindecode/models/base.py +86 -102
  29. braindecode/models/biot.py +483 -0
  30. braindecode/models/contrawr.py +296 -0
  31. braindecode/models/ctnet.py +450 -0
  32. braindecode/models/deep4.py +64 -75
  33. braindecode/models/deepsleepnet.py +149 -119
  34. braindecode/models/eegconformer.py +111 -171
  35. braindecode/models/eeginception_erp.py +109 -118
  36. braindecode/models/eeginception_mi.py +155 -97
  37. braindecode/models/eegitnet.py +215 -151
  38. braindecode/models/eegminer.py +255 -0
  39. braindecode/models/eegnet.py +229 -161
  40. braindecode/models/eegnex.py +247 -0
  41. braindecode/models/eegresnet.py +234 -152
  42. braindecode/models/eegsimpleconv.py +199 -0
  43. braindecode/models/eegtcnet.py +335 -0
  44. braindecode/models/fbcnet.py +221 -0
  45. braindecode/models/fblightconvnet.py +313 -0
  46. braindecode/models/fbmsnet.py +325 -0
  47. braindecode/models/hybrid.py +52 -71
  48. braindecode/models/ifnet.py +441 -0
  49. braindecode/models/labram.py +1166 -0
  50. braindecode/models/msvtnet.py +375 -0
  51. braindecode/models/sccnet.py +182 -0
  52. braindecode/models/shallow_fbcsp.py +50 -56
  53. braindecode/models/signal_jepa.py +1012 -0
  54. braindecode/models/sinc_shallow.py +337 -0
  55. braindecode/models/sleep_stager_blanco_2020.py +55 -46
  56. braindecode/models/sleep_stager_chambon_2018.py +54 -53
  57. braindecode/models/sleep_stager_eldele_2021.py +248 -141
  58. braindecode/models/sparcnet.py +378 -0
  59. braindecode/models/summary.csv +41 -0
  60. braindecode/models/syncnet.py +232 -0
  61. braindecode/models/tcn.py +158 -88
  62. braindecode/models/tidnet.py +280 -167
  63. braindecode/models/tsinception.py +258 -0
  64. braindecode/models/usleep.py +190 -177
  65. braindecode/models/util.py +109 -141
  66. braindecode/modules/__init__.py +38 -0
  67. braindecode/modules/activation.py +60 -0
  68. braindecode/modules/attention.py +757 -0
  69. braindecode/modules/blocks.py +108 -0
  70. braindecode/modules/convolution.py +274 -0
  71. braindecode/modules/filter.py +632 -0
  72. braindecode/modules/layers.py +133 -0
  73. braindecode/modules/linear.py +50 -0
  74. braindecode/modules/parametrization.py +38 -0
  75. braindecode/modules/stats.py +77 -0
  76. braindecode/modules/util.py +77 -0
  77. braindecode/modules/wrapper.py +75 -0
  78. braindecode/preprocessing/__init__.py +36 -11
  79. braindecode/preprocessing/mne_preprocess.py +13 -7
  80. braindecode/preprocessing/preprocess.py +139 -75
  81. braindecode/preprocessing/windowers.py +576 -187
  82. braindecode/regressor.py +23 -12
  83. braindecode/samplers/__init__.py +16 -8
  84. braindecode/samplers/base.py +148 -32
  85. braindecode/samplers/ssl.py +162 -17
  86. braindecode/training/__init__.py +18 -10
  87. braindecode/training/callbacks.py +2 -4
  88. braindecode/training/losses.py +3 -8
  89. braindecode/training/scoring.py +76 -68
  90. braindecode/util.py +55 -59
  91. braindecode/version.py +1 -1
  92. braindecode/visualization/__init__.py +2 -3
  93. braindecode/visualization/confusion_matrices.py +117 -73
  94. braindecode/visualization/gradients.py +14 -10
  95. {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info}/METADATA +39 -55
  96. braindecode-1.0.0.dist-info/RECORD +101 -0
  97. {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info}/WHEEL +1 -1
  98. {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info/licenses}/LICENSE.txt +1 -1
  99. braindecode-1.0.0.dist-info/licenses/NOTICE.txt +20 -0
  100. braindecode/datautil/mne.py +0 -9
  101. braindecode/datautil/preprocess.py +0 -12
  102. braindecode/datautil/windowers.py +0 -6
  103. braindecode/datautil/xy.py +0 -9
  104. braindecode/models/eeginception.py +0 -317
  105. braindecode/models/functions.py +0 -47
  106. braindecode/models/modules.py +0 -358
  107. braindecode-0.8.1.dist-info/RECORD +0 -68
  108. {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,133 @@
1
+ # Authors: Robin Schirrmeister <robintibor@gmail.com>
2
+ #
3
+ # License: BSD (3-clause)
4
+ from __future__ import annotations
5
+
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn.functional as F
9
+ from einops.layers.torch import Rearrange
10
+ from torch import Tensor, nn
11
+
12
+ from braindecode.functional import drop_path, safe_log
13
+
14
+
15
class Ensure4d(nn.Module):
    """Append trailing singleton dimensions until the input is at least 4-D.

    Tensors that already have four or more dimensions are returned unchanged.
    """

    def forward(self, x):
        n_missing = max(0, 4 - x.dim())
        for _ in range(n_missing):
            # Each new axis of size 1 is added at the end.
            x = x.unsqueeze(-1)
        return x
20
+
21
+
22
class Chomp1d(nn.Module):
    """Remove the last ``chomp_size`` samples along the last axis of a 3-D tensor.

    For example, it can trim the right-hand padding added by a padded 1-D
    convolution.

    Parameters
    ----------
    chomp_size : int
        Number of trailing samples to drop from the last dimension. ``0``
        leaves the input unchanged (returned as a contiguous tensor).
    """

    def __init__(self, chomp_size):
        super().__init__()
        self.chomp_size = chomp_size

    def extra_repr(self):
        return "chomp_size={}".format(self.chomp_size)

    def forward(self, x):
        # ``x[..., :-0]`` is an *empty* slice, so a zero chomp size must be
        # handled explicitly to behave as the identity instead of wiping out
        # the whole time axis.
        if self.chomp_size == 0:
            return x.contiguous()
        return x[:, :, : -self.chomp_size].contiguous()
32
+
33
+
34
class TimeDistributed(nn.Module):
    """Apply module on multiple windows.

    Apply the provided module on a sequence of windows and return their
    concatenation.
    Useful with sequence-to-prediction models (e.g. sleep stager which must map
    a sequence of consecutive windows to the label of the middle window in the
    sequence).

    Parameters
    ----------
    module : nn.Module
        Module to be applied to the input windows. Must accept an input of
        shape (batch_size, n_channels, n_times).
    """

    def __init__(self, module):
        super().__init__()
        self.module = module

    def forward(self, x):
        """Apply ``self.module`` to every window of every sequence.

        Parameters
        ----------
        x : torch.Tensor
            Sequence of windows, of shape (batch_size, seq_len, n_channels,
            n_times). The tensor does not need to be contiguous.

        Returns
        -------
        torch.Tensor
            Shape (batch_size, seq_len, output_size), where ``output_size`` is
            the number of elements the module produces per window.
        """
        b, s, c, t = x.shape
        # Fold the sequence axis into the batch axis so the module sees plain
        # (batch, channels, times) windows. ``reshape`` is used instead of
        # ``view``: ``view`` raises on non-contiguous tensors (e.g. after an
        # upstream permute/transpose), while ``reshape`` copies only if needed.
        out = self.module(x.reshape(b * s, c, t))
        # The module output may itself be non-contiguous, hence reshape again.
        return out.reshape(b, s, -1)
70
+
71
+
72
class DropPath(nn.Module):
    """Drop paths per sample, also known as Stochastic Depth.

    Intended to be applied on the main path of residual blocks. The dropping
    itself is delegated to :func:`braindecode.functional.drop_path`. That
    function receives the module's ``training`` flag, so it can tell training
    mode from evaluation mode.

    Parameters
    ----------
    drop_prob : float | None, default=None
        Drop path probability (should be in range 0-1).

    Notes
    -----
    Code copied and modified from VISSL facebookresearch:
    https://github.com/facebookresearch/vissl/blob/0b5d6a94437bc00baed112ca90c9d78c6ccfbafb/vissl/models/model_helpers.py#L676
    All rights reserved.

    THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
    OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
    SOFTWARE.
    """

    def __init__(self, drop_prob=None):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        probability = self.drop_prob
        return drop_path(x, probability, self.training)

    def extra_repr(self) -> str:
        # Rendered inside the module repr, e.g. ``DropPath(p=0.1)``.
        return f"p={self.drop_prob}"
107
+
108
+
109
class SqueezeFinalOutput(nn.Module):
    """Strip trailing singleton axes from a model output, never the batch axis.

    The einops pattern ``"b c t 1 -> b c t"`` requires a 4-D input whose last
    axis has size 1 and removes that axis. If the remaining time axis also has
    size 1, it is squeezed as well. A blanket ``torch.squeeze`` is avoided on
    purpose, because it would also drop a batch dimension of size 1.

    Returns
    -------
    x: torch.Tensor
        Tensor of shape ``(b, c, t)``, or ``(b, c)`` when ``t == 1``.
    """

    def __init__(self):
        super().__init__()

        self.squeeze = Rearrange("b c t 1 -> b c t")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        without_feature_axis = self.squeeze(x)
        time_is_singleton = without_feature_axis.shape[-1] == 1
        if not time_is_singleton:
            return without_feature_axis
        return without_feature_axis.squeeze(-1)
@@ -0,0 +1,50 @@
1
+ import torch
2
+ from torch import Tensor, nn
3
+ from torch.nn.utils.parametrize import register_parametrization
4
+
5
+ from braindecode.modules.parametrization import MaxNorm, MaxNormParametrize
6
+
7
+
8
class MaxNormLinear(nn.Linear):
    """Linear layer whose weight is max-norm constrained through parametrization.

    Equivalent of Keras tf.keras.Dense(..., kernel_constraint=max_norm())
    [1]_ and [2]_. Implemented as advised in [3]_: a :class:`MaxNorm`
    parametrization is registered on ``weight``, so the constraint is applied
    every time ``weight`` is accessed.

    Parameters
    ----------
    in_features: int
        Size of each input sample.
    out_features: int
        Size of each output sample.
    bias: bool, optional
        If set to ``False``, the layer will not learn an additive bias.
        Default: ``True``.
    max_norm_val: float, optional
        Norm limit handed to :class:`MaxNorm`. Default: ``2``.
    eps: float, optional
        Small constant handed to :class:`MaxNorm` for numerical stability.
        Default: ``1e-5``.
    **kwargs
        Extra keyword arguments forwarded to :class:`torch.nn.Linear`.

    References
    ----------
    .. [1] https://keras.io/api/layers/core_layers/dense/#dense-class
    .. [2] https://www.tensorflow.org/api_docs/python/tf/keras/constraints/
       MaxNorm
    .. [3] https://discuss.pytorch.org/t/how-to-correctly-implement-in-place-
       max-norm-constraint/96769
    """

    def __init__(
        self, in_features, out_features, bias=True, max_norm_val=2, eps=1e-5, **kwargs
    ):
        super().__init__(
            in_features=in_features, out_features=out_features, bias=bias, **kwargs
        )
        self._max_norm_val = max_norm_val
        self._eps = eps
        constraint = MaxNorm(self._max_norm_val, self._eps)
        register_parametrization(self, "weight", constraint)
42
+
43
+
44
class LinearWithConstraint(nn.Linear):
    """Linear layer with max-norm constraint on the weights.

    A :class:`MaxNormParametrize` parametrization is registered on
    ``weight``. It renormalizes the weight along dim 0 so that each slice has
    an L2 norm of at most ``max_norm``.

    Parameters
    ----------
    *args
        Positional arguments forwarded to :class:`torch.nn.Linear`.
    max_norm : float, default=1.0
        Maximum L2 norm allowed per slice along dim 0.
    **kwargs
        Keyword arguments forwarded to :class:`torch.nn.Linear`.
    """

    def __init__(self, *args, max_norm=1.0, **kwargs):
        super().__init__(*args, **kwargs)
        self.max_norm = max_norm
        constraint = MaxNormParametrize(self.max_norm)
        register_parametrization(self, "weight", constraint)
@@ -0,0 +1,38 @@
1
+ import torch
2
+ from torch import nn
3
+
4
+
5
class MaxNorm(nn.Module):
    """Parametrization that caps the L2 norms of a weight taken over dim 0.

    Norms are computed with ``X.norm(2, dim=0)``, which for a 2-D weight gives
    one norm per column. Columns whose norm exceeds ``max_norm_val`` are
    rescaled to roughly ``max_norm_val``. Other columns are multiplied by
    ``d / (d + eps)``, where ``d`` is the norm clamped below at
    ``max_norm_val / 2``, so they are left almost unchanged.

    Parameters
    ----------
    max_norm_val : float, default=2.0
        Upper bound for the column norms.
    eps : float, default=1e-5
        Added to the denominator to avoid division by zero.
    """

    def __init__(self, max_norm_val=2.0, eps=1e-5):
        super().__init__()
        self.max_norm_val = max_norm_val
        self.eps = eps

    def _scale(self, X: torch.Tensor) -> torch.Tensor:
        """Per-column multiplicative factor applied by ``forward``."""
        column_norms = X.norm(2, dim=0, keepdim=True)
        floored = column_norms.clamp(min=self.max_norm_val / 2)
        capped = floored.clamp(max=self.max_norm_val)
        return capped / (floored + self.eps)

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        return X * self._scale(X)

    def right_inverse(self, X: torch.Tensor) -> torch.Tensor:
        # Divides by the factor ``forward`` would apply to ``X``. This only
        # approximately inverts ``forward`` for columns with norm at most
        # ``max_norm_val``. Larger norms cannot be reproduced, because
        # ``forward`` caps them.
        return X / self._scale(X)
25
+
26
+
27
class MaxNormParametrize(nn.Module):
    """
    Enforce a max-norm constraint on the rows of a weight tensor via parametrization.

    Every slice along dim 0 whose L2 norm exceeds ``max_norm`` is rescaled
    down to ``max_norm``. Slices already within the bound are left unchanged.

    Parameters
    ----------
    max_norm : float, default=1.0
        Largest allowed L2 norm per slice along dim 0.
    """

    def __init__(self, max_norm: float = 1.0):
        super().__init__()
        self.max_norm = max_norm

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        return torch.renorm(X, p=2, dim=0, maxnorm=self.max_norm)
@@ -0,0 +1,77 @@
1
+ from __future__ import annotations
2
+
3
+ from functools import partial
4
+ from typing import Callable, Optional
5
+
6
+ import torch
7
+ from torch import nn
8
+
9
+
10
class StatLayer(nn.Module):
    """
    Reduce a tensor along one dimension with a statistic function.

    The result can optionally be clamped and then log-transformed, e.g. to
    build log-variance features.

    Parameters
    ----------
    stat_fn : Callable
        Reduction called as ``stat_fn(x, dim=dim, keepdim=keepdim)``, e.g.
        ``torch.mean`` or ``torch.std``.
    dim : int
        Dimension along which to apply the function.
    keepdim : bool, default=True
        Whether to keep the reduced dimension.
    clamp_range : tuple(float, float), optional
        ``(min, max)`` bounds applied to the statistic before the optional log.
    apply_log : bool, default=False
        Whether to take the natural log of the (clamped) statistic.
    """

    def __init__(
        self,
        stat_fn: Callable[..., torch.Tensor],
        dim: int,
        keepdim: bool = True,
        clamp_range: Optional[tuple[float, float]] = None,
        apply_log: bool = False,
    ) -> None:
        super().__init__()
        self.stat_fn = stat_fn
        self.dim = dim
        self.keepdim = keepdim
        self.clamp_range = clamp_range
        self.apply_log = apply_log

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        result = self.stat_fn(x, dim=self.dim, keepdim=self.keepdim)
        if self.clamp_range is not None:
            low, high = self.clamp_range[0], self.clamp_range[1]
            result = result.clamp(min=low, max=high)
        return torch.log(result) if self.apply_log else result
49
+
50
+
51
+ # make things more simple
52
+ def _max_fn(x: torch.Tensor, dim: int, keepdim: bool) -> torch.Tensor:
53
+ return x.max(dim=dim, keepdim=keepdim)[0]
54
+
55
+
56
+ def _power_fn(x: torch.Tensor, dim: int, keepdim: bool) -> torch.Tensor:
57
+ # compute mean of squared values along `dim`
58
+ return torch.mean(x**2, dim=dim, keepdim=keepdim)
59
+
60
+
61
# Ready-made StatLayer variants: ``functools.partial`` objects with the
# statistic (and, for the log variants, clamp range and log) pre-bound, so
# callers only supply ``dim`` and optionally ``keepdim``, e.g.
# ``LogVarLayer(dim=3)``. They are annotated ``Callable[..., StatLayer]``
# because the remaining arguments may be given by keyword and ``keepdim`` is
# optional (``Callable[[int, bool], ...]`` would demand two positionals).
MeanLayer: Callable[..., StatLayer] = partial(StatLayer, torch.mean)
MaxLayer: Callable[..., StatLayer] = partial(StatLayer, _max_fn)
VarLayer: Callable[..., StatLayer] = partial(StatLayer, torch.var)
StdLayer: Callable[..., StatLayer] = partial(StatLayer, torch.std)
# Variance is clamped to [1e-6, 1e6] before the log, so a zero variance does
# not produce -inf.
LogVarLayer: Callable[..., StatLayer] = partial(
    StatLayer,
    torch.var,
    clamp_range=(1e-6, 1e6),
    apply_log=True,
)

# Mean of squares is clamped to [1e-4, 1e4] before the log.
LogPowerLayer: Callable[..., StatLayer] = partial(
    StatLayer,
    _power_fn,
    clamp_range=(1e-4, 1e4),
    apply_log=True,
)
+ )
@@ -0,0 +1,77 @@
1
+ # Authors: Robin Schirrmeister <robintibor@gmail.com>
2
+ # Hubert Banville <hubert.jbanville@gmail.com>
3
+ #
4
+ # License: BSD (3-clause)
5
+
6
+ import numpy as np
7
+ import torch
8
+ from scipy.special import log_softmax
9
+
10
+
11
+ def _pad_shift_array(x, stride=1):
12
+ """Zero-pad and shift rows of a 3D array.
13
+
14
+ E.g., used to align predictions of corresponding windows in
15
+ sequence-to-sequence models.
16
+
17
+ Parameters
18
+ ----------
19
+ x : np.ndarray
20
+ Array of shape (n_rows, n_classes, n_windows).
21
+ stride : int
22
+ Number of non-overlapping elements between two consecutive sequences.
23
+
24
+ Returns
25
+ -------
26
+ np.ndarray :
27
+ Array of shape (n_rows, n_classes, (n_rows - 1) * stride + n_windows)
28
+ where each row is obtained by zero-padding the corresponding row in
29
+ ``x`` before and after in the last dimension.
30
+ """
31
+ if x.ndim != 3:
32
+ raise NotImplementedError(
33
+ f"x must be of shape (n_rows, n_classes, n_windows), got {x.shape}"
34
+ )
35
+ x_padded = np.pad(x, ((0, 0), (0, 0), (0, (x.shape[0] - 1) * stride)))
36
+ orig_strides = x_padded.strides
37
+ new_strides = (
38
+ orig_strides[0] - stride * orig_strides[2],
39
+ orig_strides[1],
40
+ orig_strides[2],
41
+ )
42
+ return np.lib.stride_tricks.as_strided(x_padded, strides=new_strides)
43
+
44
+
45
def aggregate_probas(logits, n_windows_stride=1):
    """Aggregate predicted probabilities with self-ensembling.

    Aggregate window-wise predictions obtained on overlapping sequences of
    windows using multiplicative voting as described in [Phan2018]_. The
    voting is done in log space. Logits are turned into log-probabilities by a
    softmax over the class axis, each sequence is shifted to its position in
    the recording, and the log-probabilities that different sequences predict
    for the same window are summed. Summing logs is equivalent to multiplying
    probabilities.

    Parameters
    ----------
    logits : np.ndarray
        Array of shape (n_sequences, n_classes, n_windows) containing the
        logits (i.e. the raw unnormalized scores for each class) for each
        window of each sequence.
    n_windows_stride : int
        Number of windows between two consecutive sequences. Default is 1
        (maximally overlapping sequences).

    Returns
    -------
    np.ndarray :
        Array of shape
        ((n_sequences - 1) * n_windows_stride + n_windows, n_classes). For
        each window it holds the sum of the log-probabilities predicted by
        every sequence containing that window; sequences that do not cover a
        window contribute 0. These are unnormalized log-scores, not
        probabilities. Their argmax over the class axis gives the aggregated
        prediction.

    References
    ----------
    .. [Phan2018] Phan, H., Andreotti, F., Cooray, N., Chén, O. Y., &
       De Vos, M. (2018). Joint classification and prediction CNN framework
       for automatic sleep stage classification. IEEE Transactions on
       Biomedical Engineering, 66(5), 1285-1296.
    """
    log_probas = log_softmax(logits, axis=1)
    # (n_sequences, n_classes, n_total) -> sum over sequences -> transpose to
    # (n_total, n_classes).
    return _pad_shift_array(log_probas, stride=n_windows_stride).sum(axis=0).T
@@ -0,0 +1,75 @@
1
+ from typing import List, Tuple
2
+
3
+ import torch
4
+ from torch import nn
5
+
6
+
7
+ class Expression(nn.Module):
8
+ """Compute given expression on forward pass.
9
+
10
+ Parameters
11
+ ----------
12
+ expression_fn : callable
13
+ Should accept variable number of objects of type
14
+ `torch.autograd.Variable` to compute its output.
15
+ """
16
+
17
+ def __init__(self, expression_fn):
18
+ super().__init__()
19
+ self.expression_fn = expression_fn
20
+
21
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
22
+ return self.expression_fn(x)
23
+
24
+ def __repr__(self):
25
+ if hasattr(self.expression_fn, "func") and hasattr(
26
+ self.expression_fn, "kwargs"
27
+ ):
28
+ expression_str = "{:s} {:s}".format(
29
+ self.expression_fn.func.__name__, str(self.expression_fn.kwargs)
30
+ )
31
+ elif hasattr(self.expression_fn, "__name__"):
32
+ expression_str = self.expression_fn.__name__
33
+ else:
34
+ expression_str = repr(self.expression_fn)
35
+ return self.__class__.__name__ + "(expression=%s) " % expression_str
36
+
37
+
38
class IntermediateOutputWrapper(nn.Module):
    """Wraps network model such that outputs of intermediate layers can be returned.

    forward() returns list of intermediate activations in a network during
    forward pass. The direct children of ``model`` are registered on the
    wrapper and called one after another in registration order, each output
    feeding the next child. This assumes ``model`` computes its output by
    applying its children sequentially (as :class:`torch.nn.Sequential` does).

    Parameters
    ----------
    to_select : list
        list of module names for which activation should be returned
    model : model object
        network model

    Examples
    --------
    >>> model = Deep4Net()
    >>> select_modules = ['conv_spat','conv_2','conv_3','conv_4'] # Specify intermediate outputs
    >>> model_pert = IntermediateOutputWrapper(select_modules,model) # Wrap model
    """

    def __init__(self, to_select, model):
        super().__init__()
        # ``add_module`` registers the very same child objects, so parameters
        # are shared with ``model``; no state needs to be copied. (Every child
        # yielded by ``named_children`` has a name, so no name check is needed.)
        for name, child in model.named_children():
            self.add_module(name, child)
        self._to_select = to_select

    def forward(self, x):
        """Run the children in order and collect the selected activations.

        Returns
        -------
        list of torch.Tensor
            Outputs of the modules whose names are in ``to_select``, in
            execution order.
        """
        selected = []
        for name, module in self._modules.items():
            x = module(x)
            if name in self._to_select:
                selected.append(x)
        return selected
@@ -1,12 +1,37 @@
1
- from .preprocess import (exponential_moving_demean,
2
- exponential_moving_standardize, filterbank,
3
- preprocess, Preprocessor)
4
- from .mne_preprocess import (Resample, DropChannels, SetEEGReference, Filter, Pick, Crop)
5
- from .windowers import (create_windows_from_events, create_fixed_length_windows,
6
- create_windows_from_target_channels)
1
+ from .mne_preprocess import ( # type: ignore[attr-defined]
2
+ Crop,
3
+ DropChannels,
4
+ Filter,
5
+ Pick,
6
+ Resample,
7
+ SetEEGReference,
8
+ )
9
+ from .preprocess import (
10
+ Preprocessor,
11
+ exponential_moving_demean,
12
+ exponential_moving_standardize,
13
+ filterbank,
14
+ preprocess,
15
+ )
16
+ from .windowers import (
17
+ create_fixed_length_windows,
18
+ create_windows_from_events,
19
+ create_windows_from_target_channels,
20
+ )
7
21
 
8
- __all__ = ["exponential_moving_demean", "exponential_moving_standardize",
9
- "filterbank", "preprocess", "Preprocessor", "Resample", "DropChannels",
10
- "SetEEGReference", "Filter", "Pick", "Crop",
11
- "create_windows_from_events", "create_fixed_length_windows",
12
- "create_windows_from_target_channels"]
22
# Public API of the preprocessing package, grouped by source submodule.
__all__ = [
    # from .preprocess
    "exponential_moving_demean",
    "exponential_moving_standardize",
    "filterbank",
    "preprocess",
    "Preprocessor",
    # from .mne_preprocess (MNE-based preprocessor classes)
    "Resample",
    "DropChannels",
    "SetEEGReference",
    "Filter",
    "Pick",
    "Crop",
    # from .windowers
    "create_windows_from_events",
    "create_fixed_length_windows",
    "create_windows_from_target_channels",
]
@@ -1,11 +1,14 @@
1
1
  """Preprocessor objects based on mne methods."""
2
+
2
3
  # Authors: Bruna Lopes <brunajaflopes@gmail.com>
3
4
  # Bruno Aristimunha <b.aristimunha@gmail.com>
4
5
  #
5
6
  # License: BSD-3
6
7
  import inspect
8
+
7
9
  import mne.io
8
- from braindecode.preprocessing import Preprocessor
10
+
11
+ from braindecode.preprocessing.preprocess import Preprocessor
9
12
  from braindecode.util import _update_moabb_docstring
10
13
 
11
14
 
@@ -31,9 +34,9 @@ def _generate_mne_pre_processor(function):
31
34
  """
32
35
  Generate a class based on an MNE function for preprocessing.
33
36
  """
34
- class_name = ''.join(
35
- word.title() for word in function.__name__.split('_')).replace('Eeg',
36
- 'EEG')
37
+ class_name = "".join(word.title() for word in function.__name__.split("_")).replace(
38
+ "Eeg", "EEG"
39
+ )
37
40
  import_path = f"{function.__module__}.{function.__name__}"
38
41
  doc = f" See more details in {import_path}"
39
42
 
@@ -55,7 +58,7 @@ mne_functions = [
55
58
  mne.io.Raw.filter,
56
59
  mne.io.Raw.crop,
57
60
  mne.io.Raw.pick,
58
- mne.io.Raw.set_eeg_reference
61
+ mne.io.Raw.set_eeg_reference,
59
62
  ]
60
63
 
61
64
  # Automatically generate and add classes to the global namespace
@@ -64,8 +67,11 @@ for function in mne_functions:
64
67
  globals()[class_obj.__name__] = class_obj
65
68
 
66
69
  # Define __all__ based on the generated class names
67
- __all__ = [class_obj.__name__ for class_obj in globals().values() if
68
- isinstance(class_obj, type)]
70
__all__ = [
    obj.__name__
    for obj in globals().values()
    # Export only the classes generated in this module. ``Preprocessor`` is
    # imported from ``braindecode.preprocessing.preprocess`` and should not be
    # re-exported from here.
    if isinstance(obj, type) and obj is not Preprocessor
]

# Clean up unnecessary variables. ``class_obj`` is the variable left over from
# the class-generation loop; comprehension variables do not leak in Python 3.
del mne_functions, function, class_obj