braindecode 0.8-py3-none-any.whl → 1.0.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (102)
  1. braindecode/__init__.py +1 -2
  2. braindecode/augmentation/__init__.py +50 -0
  3. braindecode/augmentation/base.py +222 -0
  4. braindecode/augmentation/functional.py +1096 -0
  5. braindecode/augmentation/transforms.py +1274 -0
  6. braindecode/classifier.py +26 -24
  7. braindecode/datasets/__init__.py +34 -0
  8. braindecode/datasets/base.py +840 -0
  9. braindecode/datasets/bbci.py +694 -0
  10. braindecode/datasets/bcicomp.py +194 -0
  11. braindecode/datasets/bids.py +245 -0
  12. braindecode/datasets/mne.py +172 -0
  13. braindecode/datasets/moabb.py +209 -0
  14. braindecode/datasets/nmt.py +311 -0
  15. braindecode/datasets/sleep_physio_challe_18.py +412 -0
  16. braindecode/datasets/sleep_physionet.py +125 -0
  17. braindecode/datasets/tuh.py +588 -0
  18. braindecode/datasets/xy.py +95 -0
  19. braindecode/datautil/__init__.py +49 -0
  20. braindecode/datautil/serialization.py +342 -0
  21. braindecode/datautil/util.py +41 -0
  22. braindecode/eegneuralnet.py +63 -47
  23. braindecode/functional/__init__.py +10 -0
  24. braindecode/functional/functions.py +251 -0
  25. braindecode/functional/initialization.py +47 -0
  26. braindecode/models/__init__.py +52 -0
  27. braindecode/models/atcnet.py +652 -0
  28. braindecode/models/attentionbasenet.py +550 -0
  29. braindecode/models/base.py +296 -0
  30. braindecode/models/biot.py +483 -0
  31. braindecode/models/contrawr.py +296 -0
  32. braindecode/models/ctnet.py +450 -0
  33. braindecode/models/deep4.py +322 -0
  34. braindecode/models/deepsleepnet.py +295 -0
  35. braindecode/models/eegconformer.py +372 -0
  36. braindecode/models/eeginception_erp.py +304 -0
  37. braindecode/models/eeginception_mi.py +371 -0
  38. braindecode/models/eegitnet.py +301 -0
  39. braindecode/models/eegminer.py +255 -0
  40. braindecode/models/eegnet.py +473 -0
  41. braindecode/models/eegnex.py +247 -0
  42. braindecode/models/eegresnet.py +362 -0
  43. braindecode/models/eegsimpleconv.py +199 -0
  44. braindecode/models/eegtcnet.py +335 -0
  45. braindecode/models/fbcnet.py +221 -0
  46. braindecode/models/fblightconvnet.py +313 -0
  47. braindecode/models/fbmsnet.py +325 -0
  48. braindecode/models/hybrid.py +126 -0
  49. braindecode/models/ifnet.py +441 -0
  50. braindecode/models/labram.py +1166 -0
  51. braindecode/models/msvtnet.py +375 -0
  52. braindecode/models/sccnet.py +182 -0
  53. braindecode/models/shallow_fbcsp.py +208 -0
  54. braindecode/models/signal_jepa.py +1012 -0
  55. braindecode/models/sinc_shallow.py +337 -0
  56. braindecode/models/sleep_stager_blanco_2020.py +167 -0
  57. braindecode/models/sleep_stager_chambon_2018.py +157 -0
  58. braindecode/models/sleep_stager_eldele_2021.py +536 -0
  59. braindecode/models/sparcnet.py +378 -0
  60. braindecode/models/summary.csv +41 -0
  61. braindecode/models/syncnet.py +232 -0
  62. braindecode/models/tcn.py +273 -0
  63. braindecode/models/tidnet.py +395 -0
  64. braindecode/models/tsinception.py +258 -0
  65. braindecode/models/usleep.py +340 -0
  66. braindecode/models/util.py +133 -0
  67. braindecode/modules/__init__.py +38 -0
  68. braindecode/modules/activation.py +60 -0
  69. braindecode/modules/attention.py +757 -0
  70. braindecode/modules/blocks.py +108 -0
  71. braindecode/modules/convolution.py +274 -0
  72. braindecode/modules/filter.py +632 -0
  73. braindecode/modules/layers.py +133 -0
  74. braindecode/modules/linear.py +50 -0
  75. braindecode/modules/parametrization.py +38 -0
  76. braindecode/modules/stats.py +77 -0
  77. braindecode/modules/util.py +77 -0
  78. braindecode/modules/wrapper.py +75 -0
  79. braindecode/preprocessing/__init__.py +37 -0
  80. braindecode/preprocessing/mne_preprocess.py +77 -0
  81. braindecode/preprocessing/preprocess.py +478 -0
  82. braindecode/preprocessing/windowers.py +1031 -0
  83. braindecode/regressor.py +23 -12
  84. braindecode/samplers/__init__.py +18 -0
  85. braindecode/samplers/base.py +401 -0
  86. braindecode/samplers/ssl.py +263 -0
  87. braindecode/training/__init__.py +23 -0
  88. braindecode/training/callbacks.py +23 -0
  89. braindecode/training/losses.py +105 -0
  90. braindecode/training/scoring.py +483 -0
  91. braindecode/util.py +55 -59
  92. braindecode/version.py +1 -1
  93. braindecode/visualization/__init__.py +8 -0
  94. braindecode/visualization/confusion_matrices.py +289 -0
  95. braindecode/visualization/gradients.py +57 -0
  96. {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info}/METADATA +39 -55
  97. braindecode-1.0.0.dist-info/RECORD +101 -0
  98. {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info}/WHEEL +1 -1
  99. {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info/licenses}/LICENSE.txt +1 -1
  100. braindecode-1.0.0.dist-info/licenses/NOTICE.txt +20 -0
  101. braindecode-0.8.dist-info/RECORD +0 -11
  102. {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info}/top_level.txt +0 -0
braindecode/models/fbmsnet.py
@@ -0,0 +1,325 @@
+ from __future__ import annotations
+
+ from functools import partial
+ from typing import Optional, Sequence
+
+ import torch
+ from einops.layers.torch import Rearrange
+ from mne.utils import warn
+ from torch import nn
+
+ from braindecode.models.base import EEGModuleMixin
+ from braindecode.models.fbcnet import _valid_layers
+ from braindecode.modules import (
+     Conv2dWithConstraint,
+     FilterBankLayer,
+     LinearWithConstraint,
+ )
+
+
+ class FBMSNet(EEGModuleMixin, nn.Module):
+     """FBMSNet from Liu et al (2022) [fbmsnet]_.
+
+     .. figure:: https://raw.githubusercontent.com/Want2Vanish/FBMSNet/refs/heads/main/FBMSNet.png
+         :align: center
+         :alt: FBMSNet Architecture
+
+     0. **FilterBank Layer**: Applies a filter bank to transform the input.
+
+     1. **Temporal Convolution Block**: Uses mixed depthwise convolution
+        (MixConv) to extract multiscale temporal features from multiview EEG
+        representations. The input is split into groups corresponding to
+        different views, each convolved with kernels of a different size.
+        Kernel sizes are set relative to the EEG sampling rate, with ratio
+        coefficients [0.5, 0.25, 0.125, 0.0625], dividing the input into
+        four groups.
+
+     2. **Spatial Convolution Block**: Applies depthwise convolution with a
+        kernel size of (n_chans, 1) to span all EEG channels, effectively
+        learning spatial filters. This is followed by batch normalization and
+        the Swish activation function. A maximum norm constraint of 2 is
+        imposed on the convolution weights to regularize the model.
+
+     3. **Temporal Log-Variance Block**: Computes the log-variance over
+        temporal segments.
+
+     4. **Classification Layer**: A fully connected layer with a weight
+        constraint.
+
+     Notes
+     -----
+     This implementation is not guaranteed to be correct and has not been
+     checked by the original authors; it has only been reimplemented from the
+     paper description and source code [fbmsnetcode]_. There is an extra layer
+     here to compute the filter bank at batch time rather than at
+     data-preparation time. This avoids data leakage and allows the model to
+     follow the braindecode convention.
+
+     Parameters
+     ----------
+     n_bands : int, default=9
+         Number of input channels (e.g., number of frequency bands).
+     stride_factor : int, default=4
+         Stride factor for temporal segmentation.
+     temporal_layer : str, default='LogVarLayer'
+         Temporal aggregation layer to use.
+     n_filters_spat : int, default=36
+         Number of output channels from the MixedConv2d layer.
+     dilatability : int, default=8
+         Expansion factor for the spatial convolution block.
+     activation : nn.Module, default=nn.SiLU
+         Activation function class to apply.
+     verbose : bool, default=False
+         Verbosity parameter passed to mne when creating the filter bank.
+
+     References
+     ----------
+     .. [fbmsnet] Liu, K., Yang, M., Yu, Z., Wang, G., & Wu, W. (2022).
+        FBMSNet: A filter-bank multi-scale convolutional neural network for
+        EEG-based motor imagery decoding. IEEE Transactions on Biomedical
+        Engineering, 70(2), 436-445.
+     .. [fbmsnetcode] Liu, K., Yang, M., Yu, Z., Wang, G., & Wu, W. (2022).
+        FBMSNet: A filter-bank multi-scale convolutional neural network for
+        EEG-based motor imagery decoding.
+        https://github.com/Want2Vanish/FBMSNet
+     """
+
+     def __init__(
+         self,
+         # Braindecode parameters
+         n_chans=None,
+         n_outputs=None,
+         chs_info=None,
+         n_times=None,
+         input_window_seconds=None,
+         sfreq=None,
+         # Model parameters
+         n_bands: int = 9,
+         n_filters_spat: int = 36,
+         temporal_layer: str = "LogVarLayer",
+         n_dim: int = 3,
+         stride_factor: int = 4,
+         dilatability: int = 8,
+         activation: nn.Module = nn.SiLU,
+         kernels_weights: Sequence[int] = (15, 31, 63, 125),
+         cnn_max_norm: float = 2,
+         linear_max_norm: float = 0.5,
+         verbose: bool = False,
+         filter_parameters: Optional[dict] = None,
+     ):
+         super().__init__(
+             n_chans=n_chans,
+             n_outputs=n_outputs,
+             chs_info=chs_info,
+             n_times=n_times,
+             input_window_seconds=input_window_seconds,
+             sfreq=sfreq,
+         )
+         del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq
+
+         # Parameters
+         self.n_bands = n_bands
+         self.n_filters_spat = n_filters_spat
+         self.n_dim = n_dim
+         self.stride_factor = stride_factor
+         self.activation = activation
+         self.dilatability = dilatability
+         self.kernels_weights = kernels_weights
+         self.filter_parameters = filter_parameters or {}
+         self.out_channels_spatial = self.n_filters_spat * self.dilatability
+
+         # Checks
+         if temporal_layer not in _valid_layers:
+             raise NotImplementedError(
+                 f"Temporal layer '{temporal_layer}' is not implemented."
+             )
+
+         if self.n_times % self.stride_factor != 0:
+             warn(
+                 f"Time dimension ({self.n_times}) is not divisible by"
+                 f" stride_factor ({self.stride_factor}). Input will be padded.",
+                 UserWarning,
+             )
+
+         # Layers
+         # Following the paper's nomenclature
+         self.spectral_filtering = FilterBankLayer(
+             n_chans=self.n_chans,
+             sfreq=self.sfreq,
+             band_filters=self.n_bands,
+             verbose=verbose,
+             **self.filter_parameters,
+         )
+         # As the bands are created internally,
+         # we get the values from the filter bank
+         self.n_bands = self.spectral_filtering.n_bands
+
+         # MixedConv2d layer
+         self.mix_conv = nn.Sequential(
+             _MixedConv2d(
+                 in_channels=self.n_bands,
+                 out_channels=self.n_filters_spat,
+                 stride=1,
+                 dilation=1,
+                 depthwise=False,
+                 kernels_weights=kernels_weights,
+             ),
+             nn.BatchNorm2d(self.n_filters_spat),
+         )
+
+         # Spatial Convolution Block (SCB)
+         self.spatial_conv = nn.Sequential(
+             Conv2dWithConstraint(
+                 in_channels=self.n_filters_spat,
+                 out_channels=self.out_channels_spatial,
+                 kernel_size=(self.n_chans, 1),
+                 groups=self.n_filters_spat,
+                 max_norm=cnn_max_norm,
+                 padding=0,
+             ),
+             nn.BatchNorm2d(self.out_channels_spatial),
+             self.activation(),
+         )
+
+         # Padding layer
+         if self.n_times % self.stride_factor != 0:
+             self.padding_size = stride_factor - (self.n_times % stride_factor)
+             self.n_times_padded = self.n_times + self.padding_size
+             self.padding_layer = nn.ConstantPad1d((0, self.padding_size), 0.0)
+         else:
+             self.padding_layer = nn.Identity()
+             self.n_times_padded = self.n_times
+
+         # Temporal aggregation layer
+         self.temporal_layer = _valid_layers[temporal_layer](dim=self.n_dim)  # type: ignore
+
+         self.flatten_layer = Rearrange("batch ... -> batch (...)")
+
+         # Final fully connected layer
+         self.final_layer = LinearWithConstraint(
+             in_features=self.out_channels_spatial * self.stride_factor,
+             out_features=self.n_outputs,
+             max_norm=linear_max_norm,
+         )
+
+     def forward(self, x):
+         """
+         Forward pass of the FBMSNet model.
+
+         Parameters
+         ----------
+         x : torch.Tensor
+             Input tensor with shape (batch_size, n_chans, n_times).
+
+         Returns
+         -------
+         torch.Tensor
+             Output tensor with shape (batch_size, n_outputs).
+         """
+         batch, _, _ = x.shape
+
+         # shape: (batch, n_chans, n_times)
+         x = self.spectral_filtering(x)
+         # shape: (batch, n_bands, n_chans, n_times)
+
+         # Mixed convolution
+         x = self.mix_conv(x)
+         # shape: (batch, self.n_filters_spat, n_chans, n_times)
+
+         # Spatial convolution block
+         x = self.spatial_conv(x)
+         # shape: (batch, self.out_channels_spatial, 1, n_times)
+
+         # Pad the input so the time dimension is divisible by the stride factor
+         x = self.padding_layer(x)
+         # shape: (batch, self.out_channels_spatial, 1, n_times_padded)
+
+         # Reshape for the temporal layer
+         x = x.view(batch, self.out_channels_spatial, self.stride_factor, -1)
+         # shape: (batch, self.out_channels_spatial, self.stride_factor, n_times_padded/self.stride_factor)
+
+         # Temporal aggregation
+         x = self.temporal_layer(x)
+         # shape: (batch, self.out_channels_spatial, self.stride_factor, 1)
+
+         # Flatten and classify
+         x = self.flatten_layer(x)
+         # shape: (batch, self.out_channels_spatial * self.stride_factor)
+
+         x = self.final_layer(x)
+         # shape: (batch, n_outputs)
+         return x
+
+
+ class _MixedConv2d(nn.Module):
+     """Mixed grouped convolution for multiscale feature extraction."""
+
+     def __init__(
+         self,
+         in_channels,
+         out_channels,
+         kernels_weights=(15, 31, 63, 125),
+         stride=1,
+         dilation=1,
+         depthwise=False,
+     ):
+         super().__init__()
+
+         num_groups = len(kernels_weights)
+         in_splits = self._split_channels(in_channels, num_groups)
+         out_splits = self._split_channels(out_channels, num_groups)
+         self.splits = in_splits
+
+         self.convs = nn.ModuleList()
+         # Create a convolutional layer for each kernel size
+         for k, in_ch, out_ch in zip(kernels_weights, in_splits, out_splits):
+             conv_groups = out_ch if depthwise else 1
+             conv = nn.Conv2d(
+                 in_channels=in_ch,
+                 out_channels=out_ch,
+                 kernel_size=(1, k),
+                 stride=stride,
+                 padding="same",
+                 dilation=dilation,
+                 groups=conv_groups,
+                 bias=False,
+             )
+             self.convs.append(conv)
+
+     @staticmethod
+     def _split_channels(num_chan, num_groups):
+         """
+         Split the total number of channels into a specified
+         number of groups as evenly as possible.
+
+         Parameters
+         ----------
+         num_chan : int
+             The total number of channels to split.
+         num_groups : int
+             The number of groups to split the channels into.
+
+         Returns
+         -------
+         list of int
+             A list containing the number of channels in each group.
+             The first group receives the remainder if the division is
+             not even.
+         """
+         split = [num_chan // num_groups for _ in range(num_groups)]
+         split[0] += num_chan - sum(split)
+         return split
+
+     def forward(self, x):
+         # Split the input tensor `x` along the channel dimension (dim=1) into
+         # groups whose sizes are given by `self.splits`, which is computed from
+         # the number of input channels and the number of kernel sizes.
+         x_split = torch.split(x, self.splits, 1)
+
+         # Apply each group's convolutional layer; `self.convs` holds the
+         # layers in the order they were added.
+         x_out = [conv(x_split[i]) for i, conv in enumerate(self.convs)]
+
+         # Concatenate the outputs from all groups along the channel
+         # dimension (dim=1) to form a single output tensor.
+         x = torch.cat(x_out, 1)
+         return x
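For orientation, the new FBMSNet module can be exercised end to end with a minimal sketch along the following lines. This assumes the 1.0.0 wheel re-exports FBMSNet from braindecode.models (as the braindecode/models/__init__.py entry in the file list suggests); all data dimensions below are hypothetical example values.

    import torch
    from braindecode.models import FBMSNet

    # Hypothetical dimensions: 22-channel EEG, 4 s windows at 250 Hz, 4 classes.
    model = FBMSNet(n_chans=22, n_outputs=4, n_times=1000, sfreq=250)

    x = torch.randn(8, 22, 1000)  # (batch_size, n_chans, n_times)
    y = model(x)                  # (batch_size, n_outputs) == (8, 4)

With the default nine filter bands and four kernel sizes, _MixedConv2d._split_channels(9, 4) assigns the band channels as [3, 2, 2, 2] across the kernel groups, while _split_channels(36, 4) gives each group 9 of the 36 output filters; since 1000 is divisible by stride_factor=4, no padding is applied here.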
braindecode/models/hybrid.py
@@ -0,0 +1,126 @@
+ # Authors: Robin Schirrmeister <robintibor@gmail.com>
+ #
+ # License: BSD (3-clause)
+
+ import torch
+ from torch import nn
+ from torch.nn import ConstantPad2d
+
+ from braindecode.models.deep4 import Deep4Net
+ from braindecode.models.shallow_fbcsp import ShallowFBCSPNet
+
+
+ class HybridNet(nn.Module):
+     """Hybrid ConvNet model from Schirrmeister, R. T. et al. (2017) [Schirrmeister2017]_.
+
+     See [Schirrmeister2017]_ for details.
+
+     References
+     ----------
+     .. [Schirrmeister2017] Schirrmeister, R. T., Springenberg, J. T., Fiederer,
+        L. D. J., Glasstetter, M., Eggensperger, K., Tangermann, M., Hutter, F.
+        & Ball, T. (2017).
+        Deep learning with convolutional neural networks for EEG decoding and
+        visualization.
+        Human Brain Mapping, Aug. 2017.
+        Online: http://dx.doi.org/10.1002/hbm.23730
+     """
+
+     def __init__(
+         self,
+         n_chans=None,
+         n_outputs=None,
+         n_times=None,
+         input_window_seconds=None,
+         sfreq=None,
+         chs_info=None,
+         activation: nn.Module = nn.ELU,
+         drop_prob: float = 0.5,
+     ):
+         super().__init__()
+         self.mapping = {
+             "final_conv.weight": "final_layer.weight",
+             "final_conv.bias": "final_layer.bias",
+         }
+
+         deep_model = Deep4Net(
+             n_chans=n_chans,
+             n_outputs=n_outputs,
+             n_filters_time=20,
+             n_filters_spat=30,
+             n_filters_2=40,
+             n_filters_3=50,
+             n_filters_4=60,
+             n_times=n_times,
+             input_window_seconds=input_window_seconds,
+             sfreq=sfreq,
+             chs_info=chs_info,
+             final_conv_length=2,
+             activation_first_conv_nonlin=activation,
+             activation_later_conv_nonlin=activation,
+             drop_prob=drop_prob,
+         )
+         shallow_model = ShallowFBCSPNet(
+             n_chans=n_chans,
+             n_outputs=n_outputs,
+             n_times=n_times,
+             input_window_seconds=input_window_seconds,
+             sfreq=sfreq,
+             chs_info=chs_info,
+             n_filters_time=30,
+             n_filters_spat=40,
+             filter_time_length=28,
+             final_conv_length=29,
+             drop_prob=drop_prob,
+         )
+         new_conv_layer = nn.Conv2d(
+             deep_model.final_layer.conv_classifier.in_channels,
+             60,
+             kernel_size=deep_model.final_layer.conv_classifier.kernel_size,
+             stride=deep_model.final_layer.conv_classifier.stride,
+         )
+         deep_model.final_layer = new_conv_layer
+
+         new_conv_layer = nn.Conv2d(
+             shallow_model.final_layer.conv_classifier.in_channels,
+             40,
+             kernel_size=shallow_model.final_layer.conv_classifier.kernel_size,
+             stride=shallow_model.final_layer.conv_classifier.stride,
+         )
+         shallow_model.final_layer = new_conv_layer
+
+         deep_model.to_dense_prediction_model()
+         shallow_model.to_dense_prediction_model()
+         self.reduced_deep_model = deep_model
+         self.reduced_shallow_model = shallow_model
+
+         self.final_layer = nn.Sequential(
+             nn.Conv2d(100, n_outputs, kernel_size=(1, 1), stride=1),
+             nn.Identity(),
+         )
+
+     def forward(self, x):
+         """Forward pass.
+
+         Parameters
+         ----------
+         x : torch.Tensor
+             Batch of EEG windows of shape (batch_size, n_channels, n_times).
+         """
+         deep_out = self.reduced_deep_model(x)
+         shallow_out = self.reduced_shallow_model(x)
+
+         n_diff_deep_shallow = deep_out.size()[2] - shallow_out.size()[2]
+
+         if n_diff_deep_shallow < 0:
+             deep_out = ConstantPad2d((0, 0, -n_diff_deep_shallow, 0), 0)(deep_out)
+         elif n_diff_deep_shallow > 0:
+             shallow_out = ConstantPad2d((0, 0, n_diff_deep_shallow, 0), 0)(shallow_out)
+
+         merged_out = torch.cat((deep_out, shallow_out), dim=1)
+
+         output = self.final_layer(merged_out)
+
+         squeezed = output.squeeze(3)
+
+         return squeezed
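Similarly, a minimal sketch for the reworked HybridNet, again assuming it is importable from braindecode.models and that the window is long enough for both sub-networks' receptive fields; the dimensions below are illustrative only.

    import torch
    from braindecode.models import HybridNet

    # Hypothetical dimensions: 22 channels, 1125 samples, 4 classes.
    model = HybridNet(n_chans=22, n_outputs=4, n_times=1125, sfreq=250)

    x = torch.randn(8, 22, 1125)  # (batch_size, n_chans, n_times)
    y = model(x)  # dense output: (batch_size, n_outputs, n_output_timesteps)

Because both sub-networks are converted with to_dense_prediction_model(), the forward pass returns one prediction per receptive-field position rather than a single vector per window; the time axes of the two outputs are aligned by zero-padding before concatenation, and the 1x1 convolution in final_layer mixes the 60 deep and 40 shallow feature maps into the class scores.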