braindecode 0.8.1__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of braindecode has been flagged as possibly problematic in the registry.
Files changed (108)
  1. braindecode/__init__.py +1 -2
  2. braindecode/augmentation/__init__.py +39 -19
  3. braindecode/augmentation/base.py +25 -28
  4. braindecode/augmentation/functional.py +237 -100
  5. braindecode/augmentation/transforms.py +326 -158
  6. braindecode/classifier.py +26 -24
  7. braindecode/datasets/__init__.py +28 -10
  8. braindecode/datasets/base.py +220 -134
  9. braindecode/datasets/bbci.py +43 -52
  10. braindecode/datasets/bcicomp.py +47 -32
  11. braindecode/datasets/bids.py +245 -0
  12. braindecode/datasets/mne.py +45 -24
  13. braindecode/datasets/moabb.py +87 -27
  14. braindecode/datasets/nmt.py +311 -0
  15. braindecode/datasets/sleep_physio_challe_18.py +412 -0
  16. braindecode/datasets/sleep_physionet.py +43 -26
  17. braindecode/datasets/tuh.py +324 -140
  18. braindecode/datasets/xy.py +27 -12
  19. braindecode/datautil/__init__.py +34 -18
  20. braindecode/datautil/serialization.py +98 -71
  21. braindecode/eegneuralnet.py +63 -47
  22. braindecode/functional/__init__.py +10 -0
  23. braindecode/functional/functions.py +251 -0
  24. braindecode/functional/initialization.py +47 -0
  25. braindecode/models/__init__.py +36 -14
  26. braindecode/models/atcnet.py +153 -159
  27. braindecode/models/attentionbasenet.py +550 -0
  28. braindecode/models/base.py +86 -102
  29. braindecode/models/biot.py +483 -0
  30. braindecode/models/contrawr.py +296 -0
  31. braindecode/models/ctnet.py +450 -0
  32. braindecode/models/deep4.py +64 -75
  33. braindecode/models/deepsleepnet.py +149 -119
  34. braindecode/models/eegconformer.py +111 -171
  35. braindecode/models/eeginception_erp.py +109 -118
  36. braindecode/models/eeginception_mi.py +155 -97
  37. braindecode/models/eegitnet.py +215 -151
  38. braindecode/models/eegminer.py +255 -0
  39. braindecode/models/eegnet.py +229 -161
  40. braindecode/models/eegnex.py +247 -0
  41. braindecode/models/eegresnet.py +234 -152
  42. braindecode/models/eegsimpleconv.py +199 -0
  43. braindecode/models/eegtcnet.py +335 -0
  44. braindecode/models/fbcnet.py +221 -0
  45. braindecode/models/fblightconvnet.py +313 -0
  46. braindecode/models/fbmsnet.py +325 -0
  47. braindecode/models/hybrid.py +52 -71
  48. braindecode/models/ifnet.py +441 -0
  49. braindecode/models/labram.py +1166 -0
  50. braindecode/models/msvtnet.py +375 -0
  51. braindecode/models/sccnet.py +182 -0
  52. braindecode/models/shallow_fbcsp.py +50 -56
  53. braindecode/models/signal_jepa.py +1012 -0
  54. braindecode/models/sinc_shallow.py +337 -0
  55. braindecode/models/sleep_stager_blanco_2020.py +55 -46
  56. braindecode/models/sleep_stager_chambon_2018.py +54 -53
  57. braindecode/models/sleep_stager_eldele_2021.py +248 -141
  58. braindecode/models/sparcnet.py +378 -0
  59. braindecode/models/summary.csv +41 -0
  60. braindecode/models/syncnet.py +232 -0
  61. braindecode/models/tcn.py +158 -88
  62. braindecode/models/tidnet.py +280 -167
  63. braindecode/models/tsinception.py +258 -0
  64. braindecode/models/usleep.py +190 -177
  65. braindecode/models/util.py +109 -141
  66. braindecode/modules/__init__.py +38 -0
  67. braindecode/modules/activation.py +60 -0
  68. braindecode/modules/attention.py +757 -0
  69. braindecode/modules/blocks.py +108 -0
  70. braindecode/modules/convolution.py +274 -0
  71. braindecode/modules/filter.py +632 -0
  72. braindecode/modules/layers.py +133 -0
  73. braindecode/modules/linear.py +50 -0
  74. braindecode/modules/parametrization.py +38 -0
  75. braindecode/modules/stats.py +77 -0
  76. braindecode/modules/util.py +77 -0
  77. braindecode/modules/wrapper.py +75 -0
  78. braindecode/preprocessing/__init__.py +36 -11
  79. braindecode/preprocessing/mne_preprocess.py +13 -7
  80. braindecode/preprocessing/preprocess.py +139 -75
  81. braindecode/preprocessing/windowers.py +576 -187
  82. braindecode/regressor.py +23 -12
  83. braindecode/samplers/__init__.py +16 -8
  84. braindecode/samplers/base.py +148 -32
  85. braindecode/samplers/ssl.py +162 -17
  86. braindecode/training/__init__.py +18 -10
  87. braindecode/training/callbacks.py +2 -4
  88. braindecode/training/losses.py +3 -8
  89. braindecode/training/scoring.py +76 -68
  90. braindecode/util.py +55 -59
  91. braindecode/version.py +1 -1
  92. braindecode/visualization/__init__.py +2 -3
  93. braindecode/visualization/confusion_matrices.py +117 -73
  94. braindecode/visualization/gradients.py +14 -10
  95. {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info}/METADATA +39 -55
  96. braindecode-1.0.0.dist-info/RECORD +101 -0
  97. {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info}/WHEEL +1 -1
  98. {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info/licenses}/LICENSE.txt +1 -1
  99. braindecode-1.0.0.dist-info/licenses/NOTICE.txt +20 -0
  100. braindecode/datautil/mne.py +0 -9
  101. braindecode/datautil/preprocess.py +0 -12
  102. braindecode/datautil/windowers.py +0 -6
  103. braindecode/datautil/xy.py +0 -9
  104. braindecode/models/eeginception.py +0 -317
  105. braindecode/models/functions.py +0 -47
  106. braindecode/models/modules.py +0 -358
  107. braindecode-0.8.1.dist-info/RECORD +0 -68
  108. {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info}/top_level.txt +0 -0
braindecode/modules/blocks.py (new file, entry 69 above)
@@ -0,0 +1,108 @@
+import torch
+from torch import nn
+
+
+class InceptionBlock(nn.Module):
+    """
+    Inception block module.
+
+    This module applies multiple convolutional branches to the input and concatenates
+    their outputs along the channel dimension. Each branch can have a different
+    configuration, allowing the model to capture multi-scale features.
+
+    Parameters
+    ----------
+    branches : list of nn.Module
+        List of convolutional branches to apply to the input.
+    """
+
+    def __init__(self, branches):
+        super().__init__()
+        self.branches = nn.ModuleList(branches)
+
+    def forward(self, x):
+        return torch.cat([branch(x) for branch in self.branches], dim=1)
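
As a quick orientation (an illustrative annotation, not part of the diff): a sketch of wiring up an InceptionBlock with two branches of different kernel sizes. The branch shapes below are made up for the example.

import torch
from torch import nn
# Assumed import path; the class is defined in braindecode/modules/blocks.py:
# from braindecode.modules.blocks import InceptionBlock

branches = [
    nn.Conv2d(8, 16, kernel_size=(3, 1), padding=(1, 0)),
    nn.Conv2d(8, 16, kernel_size=(7, 1), padding=(3, 0)),
]
block = InceptionBlock(branches)
x = torch.randn(2, 8, 100, 1)              # (batch, channels, time, 1)
assert block(x).shape == (2, 32, 100, 1)   # 16 + 16 channels after concat
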
+
+
+class MLP(nn.Sequential):
+    r"""Multilayer Perceptron (MLP) with GELU activation and optional dropout.
+
+    Also known as a fully connected feedforward network, an MLP is a sequence of
+    non-linear parametric functions
+
+    .. math:: h_{i + 1} = a_{i + 1}(h_i W_{i + 1}^T + b_{i + 1}),
+
+    over feature vectors :math:`h_i`, with the input and output feature vectors
+    :math:`x = h_0` and :math:`y = h_L`, respectively. The non-linear functions
+    :math:`a_i` are called activation functions. The trainable parameters of an
+    MLP are its weights and biases :math:`\phi = \{W_i, b_i \mid i = 1, \dots, L\}`.
+
+    Parameters
+    ----------
+    in_features : int
+        Number of input features.
+    hidden_features : Sequence[int], optional (default=None)
+        Widths of the hidden layers; if None, set to (in_features, in_features).
+        You can grow the MLP by passing more integers: n hidden entries yield
+        n + 1 linear layers in total (n hidden plus the output layer).
+    out_features : int, optional (default=None)
+        Number of output features; if None, set to in_features.
+    activation : nn.Module class (default=nn.GELU)
+        The activation function constructor.
+    drop : float (default=0.0)
+        Dropout rate applied after the last linear layer.
+    normalize : bool (default=False)
+        Whether to apply layer normalization after each activation.
+    """
+
+    def __init__(
+        self,
+        in_features: int,
+        hidden_features=None,
+        out_features=None,
+        activation=nn.GELU,
+        drop=0.0,
+        normalize=False,
+    ):
+        # nn.LayerNorm needs the feature dimension, so wrap both cases in a
+        # callable taking the layer width; the disabled case returns None,
+        # which is filtered out below.
+        self.normalization = (
+            (lambda dim: nn.LayerNorm(dim)) if normalize else (lambda dim: None)
+        )
+        self.in_features = in_features
+        self.out_features = out_features or self.in_features
+        if hidden_features:
+            self.hidden_features = hidden_features
+        else:
+            self.hidden_features = (self.in_features, self.in_features)
+        self.activation = activation
+
+        layers = []
+
+        for before, after in zip(
+            (self.in_features, *self.hidden_features),
+            (*self.hidden_features, self.out_features),
+        ):
+            layers.extend(
+                [
+                    nn.Linear(in_features=before, out_features=after),
+                    self.activation(),
+                    self.normalization(after),
+                ]
+            )
+
+        # Drop the activation and normalization after the final linear layer
+        # and end with dropout instead.
+        layers = layers[:-2]
+        layers.append(nn.Dropout(p=drop))
+
+        # Remove the None placeholders left when normalization is disabled.
+        layers = list(filter(lambda layer: layer is not None, layers))
+
+        super().__init__(*layers)
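
A shape sketch for MLP (illustrative, not part of the diff; assumes the class as defined above is importable):

import torch
# Assumed: from braindecode.modules.blocks import MLP

mlp = MLP(in_features=32, hidden_features=(64, 64), out_features=10, drop=0.1)
x = torch.randn(5, 32)
assert mlp(x).shape == (5, 10)
# Resulting stack: Linear(32->64), GELU, Linear(64->64), GELU, Linear(64->10), Dropout.
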
+
+
+class FeedForwardBlock(nn.Sequential):
+    """Transformer-style feedforward block: expand, activate, drop, project back."""
+
+    def __init__(self, emb_size, expansion, drop_p, activation: nn.Module = nn.GELU):
+        super().__init__(
+            nn.Linear(emb_size, expansion * emb_size),
+            activation(),
+            nn.Dropout(drop_p),
+            nn.Linear(expansion * emb_size, emb_size),
+        )
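
FeedForwardBlock is the standard transformer feedforward: expand by expansion, apply activation and dropout, project back to emb_size, so the token shape is preserved. A minimal sketch (illustrative, not part of the diff):

import torch
# Assumed: from braindecode.modules.blocks import FeedForwardBlock

ff = FeedForwardBlock(emb_size=40, expansion=4, drop_p=0.5)
tokens = torch.randn(2, 61, 40)    # (batch, sequence, embedding)
assert ff(tokens).shape == (2, 61, 40)
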
braindecode/modules/convolution.py (new file, entry 70 above)
@@ -0,0 +1,274 @@
+from typing import Optional
+
+import numpy as np
+import torch
+from torch import nn
+from torch.nn import functional as F
+from torch.nn.utils.parametrize import register_parametrization
+
+from braindecode.util import np_to_th
+
+from .parametrization import MaxNormParametrize
+
+
+class AvgPool2dWithConv(nn.Module):
+    """
+    Average pooling implemented as a grouped convolution, so that a dilation
+    parameter is available (``nn.AvgPool2d`` has none).
+
+    Parameters
+    ----------
+    kernel_size : (int, int)
+        Size of the pooling region.
+    stride : (int, int)
+        Stride of the pooling operation.
+    dilation : int or (int, int)
+        Dilation applied to the pooling filter.
+    padding : int or (int, int)
+        Padding applied before the pooling operation.
+    """
+
+    def __init__(self, kernel_size, stride, dilation=1, padding=0):
+        super().__init__()
+        self.kernel_size = kernel_size
+        self.stride = stride
+        self.dilation = dilation
+        self.padding = padding
+        # Don't name them "weights" to make sure these are not accidentally
+        # used by some procedure that initializes parameters.
+        self._pool_weights = None
+
+    def forward(self, x):
+        # Create weights for the convolution on demand if the size, device
+        # or dtype of x changed...
+        in_channels = x.size()[1]
+        weight_shape = (
+            in_channels,
+            1,
+            self.kernel_size[0],
+            self.kernel_size[1],
+        )
+        if self._pool_weights is None or (
+            (tuple(self._pool_weights.size()) != tuple(weight_shape))
+            or (self._pool_weights.is_cuda != x.is_cuda)
+            or (self._pool_weights.data.type() != x.data.type())
+        ):
+            n_pool = np.prod(self.kernel_size)
+            weights = np_to_th(np.ones(weight_shape, dtype=np.float32) / float(n_pool))
+            weights = weights.type_as(x)
+            if x.is_cuda:
+                weights = weights.cuda()
+            self._pool_weights = weights
+
+        pooled = F.conv2d(
+            x,
+            self._pool_weights,
+            bias=None,
+            stride=self.stride,
+            dilation=self.dilation,
+            padding=self.padding,
+            groups=in_channels,
+        )
+        return pooled
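
Because the pooling weights are a constant 1/n over the kernel, the layer should agree with nn.AvgPool2d whenever dilation is 1; a sketch of that check (illustrative, not part of the diff):

import torch
from torch import nn
# Assumed: from braindecode.modules.convolution import AvgPool2dWithConv

pool = AvgPool2dWithConv(kernel_size=(3, 3), stride=(2, 2))
ref = nn.AvgPool2d(kernel_size=(3, 3), stride=(2, 2))
x = torch.randn(1, 4, 16, 16)
assert torch.allclose(pool(x), ref(x), atol=1e-6)
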
+
+
+class Conv2dWithConstraint(nn.Conv2d):
+    """Conv2d with a max-norm constraint on the weights, enforced through a
+    parametrization (see ``MaxNormParametrize``)."""
+
+    def __init__(self, *args, max_norm=1, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.max_norm = max_norm
+        # Initialize the weights, then register the max-norm parametrization.
+        nn.init.xavier_uniform_(self.weight, gain=1)
+        register_parametrization(self, "weight", MaxNormParametrize(self.max_norm))
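
MaxNormParametrize itself lives in braindecode/modules/parametrization.py (+38 lines, not shown in this section). Assuming it caps the norm of each output filter at max_norm, the constraint could be observed as follows; treat this as a sketch under that assumption:

import torch
# Assumed: from braindecode.modules.convolution import Conv2dWithConstraint

conv = Conv2dWithConstraint(4, 8, kernel_size=3, max_norm=1.0)
# Per-output-filter L2 norms (assumes the parametrization renorms along dim 0).
norms = conv.weight.view(8, -1).norm(dim=1)
assert (norms <= 1.0 + 1e-5).all()
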
+
+
+class CombinedConv(nn.Module):
+    """Merged convolutional layer for the temporal and spatial convolutions in
+    Deep4/ShallowFBCSP.
+
+    Numerically equivalent to the separate sequential approach, but this should
+    be faster.
+
+    Parameters
+    ----------
+    in_chans : int
+        Number of EEG input channels.
+    n_filters_time : int
+        Number of temporal filters.
+    filter_time_length : int
+        Length of the temporal filter.
+    n_filters_spat : int
+        Number of spatial filters.
+    bias_time : bool
+        Whether to use bias in the temporal conv.
+    bias_spat : bool
+        Whether to use bias in the spatial conv.
+    """
+
+    def __init__(
+        self,
+        in_chans,
+        n_filters_time=40,
+        n_filters_spat=40,
+        filter_time_length=25,
+        bias_time=True,
+        bias_spat=True,
+    ):
+        super().__init__()
+        self.bias_time = bias_time
+        self.bias_spat = bias_spat
+        self.conv_time = nn.Conv2d(
+            1, n_filters_time, (filter_time_length, 1), bias=bias_time, stride=1
+        )
+        self.conv_spat = nn.Conv2d(
+            n_filters_time, n_filters_spat, (1, in_chans), bias=bias_spat, stride=1
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        # Merge temporal and spatial weights into a single kernel
+        combined_weight = (
+            (self.conv_time.weight * self.conv_spat.weight.permute(1, 0, 2, 3))
+            .sum(0)
+            .unsqueeze(1)
+        )
+
+        calculated_bias: Optional[torch.Tensor] = None
+
+        # Calculate bias terms
+        if self.bias_time:
+            time_bias = self.conv_time.bias
+            assert time_bias is not None
+            calculated_bias = (
+                self.conv_spat.weight.squeeze()
+                .sum(-1)
+                .mm(time_bias.unsqueeze(-1))
+                .squeeze()
+            )
+        if self.bias_spat:
+            spat_bias = self.conv_spat.bias
+            assert spat_bias is not None
+            if calculated_bias is None:
+                calculated_bias = spat_bias
+            else:
+                calculated_bias = calculated_bias + spat_bias
+
+        return F.conv2d(x, weight=combined_weight, bias=calculated_bias, stride=(1, 1))
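
The docstring claims numerical equivalence with applying conv_time and then conv_spat sequentially; that claim can be sanity-checked directly (illustrative sketch, not part of the diff):

import torch
# Assumed: from braindecode.modules.convolution import CombinedConv

comb = CombinedConv(in_chans=22)
x = torch.randn(2, 1, 1000, 22)    # (batch, 1, time, EEG channels)
sequential = comb.conv_spat(comb.conv_time(x))
assert torch.allclose(comb(x), sequential, atol=1e-4)
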
+
+
+class CausalConv1d(nn.Conv1d):
+    """Causal 1-dimensional convolution.
+
+    Code modified from [1]_ and [2]_.
+
+    Parameters
+    ----------
+    in_channels : int
+        Input channels.
+    out_channels : int
+        Output channels (number of filters).
+    kernel_size : int
+        Kernel size.
+    dilation : int, optional
+        Dilation (number of elements to skip within kernel multiplication).
+        Defaults to 1.
+    **kwargs :
+        Other keyword arguments to pass to torch.nn.Conv1d, except for
+        ``padding``, which is computed internally.
+
+    References
+    ----------
+    .. [1] https://discuss.pytorch.org/t/causal-convolution/3456/4
+    .. [2] https://gist.github.com/paultsw/7a9d6e3ce7b70e9e2c61bc9287addefc
+    """
+
+    def __init__(
+        self,
+        in_channels,
+        out_channels,
+        kernel_size,
+        dilation=1,
+        **kwargs,
+    ):
+        assert "padding" not in kwargs, (
+            "The padding parameter is controlled internally by "
+            f"{type(self).__name__} class. You should not try to override this"
+            " parameter."
+        )
+
+        # Pad by (kernel_size - 1) * dilation so each output depends only on
+        # current and past samples; the excess trailing samples produced by
+        # the symmetric padding are trimmed off in forward().
+        super().__init__(
+            in_channels=in_channels,
+            out_channels=out_channels,
+            kernel_size=kernel_size,
+            dilation=dilation,
+            padding=(kernel_size - 1) * dilation,
+            **kwargs,
+        )
+
+    def forward(self, X):
+        out = F.conv1d(
+            X,
+            self.weight,
+            self.bias,
+            stride=self.stride,
+            padding=self.padding,
+            dilation=self.dilation,
+            groups=self.groups,
+        )
+        return out[..., : -self.padding[0]]
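
A sketch of the causal property (illustrative, not part of the diff): the output has the same length as the input, and perturbing a future sample leaves earlier outputs unchanged, since the output at time t only sees inputs at t and earlier.

import torch
# Assumed: from braindecode.modules.convolution import CausalConv1d

conv = CausalConv1d(in_channels=1, out_channels=1, kernel_size=3, dilation=2)
x = torch.randn(1, 1, 50)
y = conv(x)
assert y.shape == x.shape
x2 = x.clone()
x2[..., 30] = 99.0                 # perturb a "future" sample
assert torch.allclose(conv(x2)[..., :30], y[..., :30])
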
+
+
+class DepthwiseConv2d(torch.nn.Conv2d):
+    """
+    Depthwise convolution layer.
+
+    This class implements a depthwise convolution, where each input channel is
+    convolved separately with its own filter (channel multiplier), effectively
+    performing a spatial convolution independently over each channel.
+
+    Parameters
+    ----------
+    in_channels : int
+        Number of channels in the input tensor.
+    depth_multiplier : int, optional
+        Multiplier for the number of output channels. The total number of
+        output channels will be `in_channels * depth_multiplier`. Default is 2.
+    kernel_size : int or tuple, optional
+        Size of the convolutional kernel. Default is 3.
+    stride : int or tuple, optional
+        Stride of the convolution. Default is 1.
+    padding : int or tuple, optional
+        Padding added to both sides of the input. Default is 0.
+    dilation : int or tuple, optional
+        Spacing between kernel elements. Default is 1.
+    bias : bool, optional
+        If True, adds a learnable bias to the output. Default is True.
+    padding_mode : str, optional
+        Padding mode to use. Options are 'zeros', 'reflect', 'replicate', or
+        'circular'. Default is 'zeros'.
+    """
+
+    def __init__(
+        self,
+        in_channels,
+        depth_multiplier=2,
+        kernel_size=3,
+        stride=1,
+        padding=0,
+        dilation=1,
+        bias=True,
+        padding_mode="zeros",
+    ):
+        out_channels = in_channels * depth_multiplier
+        super().__init__(
+            in_channels=in_channels,
+            out_channels=out_channels,
+            kernel_size=kernel_size,
+            stride=stride,
+            padding=padding,
+            dilation=dilation,
+            groups=in_channels,
+            bias=bias,
+            padding_mode=padding_mode,
+        )
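
A shape sketch for the depthwise layer (illustrative, not part of the diff): groups=in_channels means each input channel is convolved only with its own depth_multiplier filters.

import torch
# Assumed: from braindecode.modules.convolution import DepthwiseConv2d

dw = DepthwiseConv2d(in_channels=8, depth_multiplier=2, kernel_size=3)
x = torch.randn(1, 8, 32, 32)
assert dw(x).shape == (1, 16, 30, 30)   # 8 * 2 output channels, no padding
assert dw.groups == 8                   # one filter group per input channel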