braindecode 0.8.1__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of braindecode might be problematic. Click here for more details.

Files changed (108)
  1. braindecode/__init__.py +1 -2
  2. braindecode/augmentation/__init__.py +39 -19
  3. braindecode/augmentation/base.py +25 -28
  4. braindecode/augmentation/functional.py +237 -100
  5. braindecode/augmentation/transforms.py +325 -158
  6. braindecode/classifier.py +26 -24
  7. braindecode/datasets/__init__.py +28 -10
  8. braindecode/datasets/base.py +220 -134
  9. braindecode/datasets/bbci.py +43 -52
  10. braindecode/datasets/bcicomp.py +47 -32
  11. braindecode/datasets/bids.py +245 -0
  12. braindecode/datasets/mne.py +45 -24
  13. braindecode/datasets/moabb.py +87 -27
  14. braindecode/datasets/nmt.py +311 -0
  15. braindecode/datasets/sleep_physio_challe_18.py +412 -0
  16. braindecode/datasets/sleep_physionet.py +43 -26
  17. braindecode/datasets/tuh.py +324 -140
  18. braindecode/datasets/xy.py +27 -12
  19. braindecode/datautil/__init__.py +37 -18
  20. braindecode/datautil/serialization.py +110 -72
  21. braindecode/eegneuralnet.py +63 -47
  22. braindecode/functional/__init__.py +22 -0
  23. braindecode/functional/functions.py +250 -0
  24. braindecode/functional/initialization.py +47 -0
  25. braindecode/models/__init__.py +84 -14
  26. braindecode/models/atcnet.py +193 -164
  27. braindecode/models/attentionbasenet.py +599 -0
  28. braindecode/models/base.py +86 -102
  29. braindecode/models/biot.py +504 -0
  30. braindecode/models/contrawr.py +317 -0
  31. braindecode/models/ctnet.py +536 -0
  32. braindecode/models/deep4.py +116 -77
  33. braindecode/models/deepsleepnet.py +149 -119
  34. braindecode/models/eegconformer.py +112 -173
  35. braindecode/models/eeginception_erp.py +109 -118
  36. braindecode/models/eeginception_mi.py +161 -97
  37. braindecode/models/eegitnet.py +215 -152
  38. braindecode/models/eegminer.py +254 -0
  39. braindecode/models/eegnet.py +228 -161
  40. braindecode/models/eegnex.py +247 -0
  41. braindecode/models/eegresnet.py +234 -152
  42. braindecode/models/eegsimpleconv.py +199 -0
  43. braindecode/models/eegtcnet.py +335 -0
  44. braindecode/models/fbcnet.py +221 -0
  45. braindecode/models/fblightconvnet.py +313 -0
  46. braindecode/models/fbmsnet.py +324 -0
  47. braindecode/models/hybrid.py +52 -71
  48. braindecode/models/ifnet.py +441 -0
  49. braindecode/models/labram.py +1186 -0
  50. braindecode/models/msvtnet.py +375 -0
  51. braindecode/models/sccnet.py +207 -0
  52. braindecode/models/shallow_fbcsp.py +50 -56
  53. braindecode/models/signal_jepa.py +1011 -0
  54. braindecode/models/sinc_shallow.py +337 -0
  55. braindecode/models/sleep_stager_blanco_2020.py +55 -46
  56. braindecode/models/sleep_stager_chambon_2018.py +54 -53
  57. braindecode/models/sleep_stager_eldele_2021.py +247 -141
  58. braindecode/models/sparcnet.py +424 -0
  59. braindecode/models/summary.csv +41 -0
  60. braindecode/models/syncnet.py +232 -0
  61. braindecode/models/tcn.py +158 -88
  62. braindecode/models/tidnet.py +280 -167
  63. braindecode/models/tsinception.py +283 -0
  64. braindecode/models/usleep.py +190 -177
  65. braindecode/models/util.py +109 -145
  66. braindecode/modules/__init__.py +84 -0
  67. braindecode/modules/activation.py +60 -0
  68. braindecode/modules/attention.py +757 -0
  69. braindecode/modules/blocks.py +108 -0
  70. braindecode/modules/convolution.py +274 -0
  71. braindecode/modules/filter.py +628 -0
  72. braindecode/modules/layers.py +131 -0
  73. braindecode/modules/linear.py +49 -0
  74. braindecode/modules/parametrization.py +38 -0
  75. braindecode/modules/stats.py +77 -0
  76. braindecode/modules/util.py +76 -0
  77. braindecode/modules/wrapper.py +73 -0
  78. braindecode/preprocessing/__init__.py +36 -11
  79. braindecode/preprocessing/mne_preprocess.py +13 -7
  80. braindecode/preprocessing/preprocess.py +139 -75
  81. braindecode/preprocessing/windowers.py +576 -187
  82. braindecode/regressor.py +23 -12
  83. braindecode/samplers/__init__.py +16 -8
  84. braindecode/samplers/base.py +146 -32
  85. braindecode/samplers/ssl.py +162 -17
  86. braindecode/training/__init__.py +18 -10
  87. braindecode/training/callbacks.py +2 -4
  88. braindecode/training/losses.py +3 -8
  89. braindecode/training/scoring.py +76 -68
  90. braindecode/util.py +55 -59
  91. braindecode/version.py +1 -1
  92. braindecode/visualization/__init__.py +2 -3
  93. braindecode/visualization/confusion_matrices.py +117 -73
  94. braindecode/visualization/gradients.py +14 -10
  95. {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/METADATA +42 -58
  96. braindecode-1.1.0.dist-info/RECORD +101 -0
  97. {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/WHEEL +1 -1
  98. {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info/licenses}/LICENSE.txt +1 -1
  99. braindecode-1.1.0.dist-info/licenses/NOTICE.txt +20 -0
  100. braindecode/datautil/mne.py +0 -9
  101. braindecode/datautil/preprocess.py +0 -12
  102. braindecode/datautil/windowers.py +0 -6
  103. braindecode/datautil/xy.py +0 -9
  104. braindecode/models/eeginception.py +0 -317
  105. braindecode/models/functions.py +0 -47
  106. braindecode/models/modules.py +0 -358
  107. braindecode-0.8.1.dist-info/RECORD +0 -68
  108. {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,317 @@
1
+ from __future__ import annotations
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from mne.utils import warn
6
+
7
+ from braindecode.models.base import EEGModuleMixin
8
+
9
+
10
class ContraWR(EEGModuleMixin, nn.Module):
    """Contrast with the World Representation (ContraWR) from Yang et al. (2023) [Yang2021]_.

    This model is a convolutional neural network that operates on a spectral
    representation of the input (the magnitude of a short-time Fourier
    transform). It then applies a series of 2D convolutional residual blocks.
    It was designed to learn representations of EEG signals for sleep staging.

    Parameters
    ----------
    steps : int, default=20
        Number of STFT hops per FFT window: the hop length of the frequency
        decomposition is ``n_fft // steps``. ``n_fft`` equals ``int(sfreq)``
        when the input window lasts at least 1 second, and ``n_times``
        otherwise.
    emb_size : int, default=256
        Number of output channels of the last residual block, i.e. the size
        of the embedding fed to the final linear layer.
    res_channels : list[int] | None, default=None
        Number of output channels of each intermediate residual block. When
        ``None``, ``[32, 64, 128]`` is used. A tuple is also accepted.
    activation : type[nn.Module], default=nn.ELU
        Activation function class to apply. Should be a PyTorch activation
        module class like ``nn.ReLU`` or ``nn.ELU``. Default is ``nn.ELU``.
    drop_prob : float, default=0.5
        The dropout rate for regularization. Values should be between 0 and 1.
    stride_res : int, default=2
        Stride used inside each residual block (first convolution, residual
        convolution and max pooling).
    kernel_size_res : int, default=3
        Kernel size of the convolutions and max pooling of each residual block.
    padding_res : int, default=1
        Padding of the convolutions and max pooling of each residual block.

    .. versionadded:: 0.9

    Notes
    -----
    This implementation has not been checked by the original authors and is
    therefore not guaranteed to be correct. The modifications with respect to
    the original code [Code2023]_ are minimal, and the model is expected to
    work as intended.

    References
    ----------
    .. [Yang2021] Yang, C., Xiao, C., Westover, M. B., & Sun, J. (2023).
       Self-supervised electroencephalogram representation learning for automatic
       sleep staging: model development and evaluation study. JMIR AI, 2(1), e46769.
    .. [Code2023] Yang, C., Westover, M.B. and Sun, J., 2023. BIOT
       Biosignal Transformer for Cross-data Learning in the Wild.
       GitHub https://github.com/ycq091044/BIOT (accessed 2024-02-13)
    """

    def __init__(
        self,
        n_chans=None,
        n_outputs=None,
        sfreq=None,
        emb_size: int = 256,
        res_channels: list[int] | None = None,
        steps=20,
        activation: type[nn.Module] = nn.ELU,
        drop_prob: float = 0.5,
        stride_res: int = 2,
        kernel_size_res: int = 3,
        padding_res: int = 1,
        # Another way to pass the EEG parameters
        chs_info=None,
        n_times=None,
        input_window_seconds=None,
    ):
        super().__init__(
            n_outputs=n_outputs,
            n_chans=n_chans,
            chs_info=chs_info,
            n_times=n_times,
            input_window_seconds=input_window_seconds,
            sfreq=sfreq,
        )
        del n_outputs, n_chans, chs_info, n_times, sfreq, input_window_seconds

        # ``None`` sentinel instead of a mutable list default.
        if res_channels is None:
            res_channels = [32, 64, 128]
        if not isinstance(res_channels, (list, tuple)):
            raise ValueError("res_channels must be a list of integers.")

        if self.input_window_seconds < 1.0:
            warning_msg = (
                "The input window is less than 1 second, which may not be "
                "sufficient for the model to learn meaningful representations. "
                "Changing the `n_fft` to `n_times`."
            )
            warn(warning_msg, UserWarning)
            # A 1-second FFT window would not fit in the input.
            self.n_fft = self.n_times
        else:
            # FFT window of one second of signal.
            self.n_fft = int(self.sfreq)

        self.steps = steps

        # Channel progression: EEG channels -> intermediate blocks -> embedding.
        channel_plan = [self.n_chans, *res_channels, emb_size]

        self.torch_stft = _STFTModule(
            n_fft=self.n_fft,
            hop_length=int(self.n_fft // self.steps),
        )

        self.convs = nn.ModuleList(
            [
                _ResBlock(
                    in_channels=in_ch,
                    out_channels=out_ch,
                    stride=stride_res,
                    use_downsampling=True,
                    pooling=True,
                    drop_prob=drop_prob,
                    kernel_size=kernel_size_res,
                    padding=padding_res,
                    activation=activation,
                )
                for in_ch, out_ch in zip(channel_plan[:-1], channel_plan[1:])
            ]
        )
        self.adaptative_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.flatten_layer = nn.Flatten()

        self.activation_layer = activation()
        self.final_layer = nn.Linear(emb_size, self.n_outputs)

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        """
        Forward pass.

        Parameters
        ----------
        X : Tensor
            Input tensor of shape (batch_size, n_channels, n_times).

        Returns
        -------
        Tensor
            Output tensor of shape (batch_size, n_outputs).
        """
        # Magnitude spectrogram: (batch, n_chans, n_freqs, n_frames).
        X = self.torch_stft(X)

        for conv in self.convs:
            # Call the module itself (not ``.forward``) so hooks are honoured.
            X = conv(X)

        # Global average over the spectro-temporal axes -> (batch, emb_size).
        emb = self.adaptative_pool(X)
        emb = self.flatten_layer(emb)
        emb = self.activation_layer(emb)

        return self.final_layer(emb)
146
+
147
+
148
+ class _ResBlock(nn.Module):
149
+ """Convolutional Residual Block 2D.
150
+
151
+ This block stacks two convolutional layers with batch normalization,
152
+ max pooling, dropout, and residual connection.
153
+
154
+ Parameters
155
+ ----------
156
+ in_channels : int
157
+ Number of input channels.
158
+ out_channels : int
159
+ Number of output channels.
160
+ stride : int (default=1)
161
+ Stride of the convolutional layers.
162
+ use_downsampling : bool (default=True)
163
+ Whether to use a downsampling residual connection.
164
+ pooling : bool (default=True)
165
+ Whether to use max pooling.
166
+ kernel_size : int (default=3)
167
+ Kernel size of the convolutional layers.
168
+ padding : int (default=1)
169
+ Padding of the convolutional layers.
170
+ activation: nn.Module, default=nn.ELU
171
+ Activation function class to apply. Should be a PyTorch activation
172
+ module class like ``nn.ReLU`` or ``nn.ELU``. Default is ``nn.ReLU``.
173
+ drop_prob : float, default=0.5
174
+ The dropout rate for regularization. Values should be between 0 and 1.
175
+
176
+ Examples
177
+ --------
178
+ >>> import torch
179
+ >>> model = ResBlock2D(6, 16, 1, True, True)
180
+ >>> input_ = torch.randn((16, 6, 28, 150)) # (batch, channel, height, width)
181
+ >>> output = model(input_)
182
+ >>> output.shape
183
+ torch.Size([16, 16, 14, 75])
184
+ """
185
+
186
+ def __init__(
187
+ self,
188
+ in_channels,
189
+ out_channels,
190
+ stride=1,
191
+ use_downsampling=True,
192
+ pooling=True,
193
+ kernel_size=3,
194
+ padding=1,
195
+ drop_prob=0.5,
196
+ activation: nn.Module = nn.ReLU,
197
+ ):
198
+ super().__init__()
199
+ self.conv1 = nn.Conv2d(
200
+ in_channels=in_channels,
201
+ out_channels=out_channels,
202
+ kernel_size=kernel_size,
203
+ stride=stride,
204
+ padding=padding,
205
+ )
206
+ self.bn1 = nn.BatchNorm2d(out_channels)
207
+ self.relu = activation()
208
+ self.conv2 = nn.Conv2d(
209
+ in_channels=out_channels,
210
+ out_channels=out_channels,
211
+ kernel_size=kernel_size,
212
+ padding=padding,
213
+ )
214
+ self.bn2 = nn.BatchNorm2d(out_channels)
215
+ self.maxpool = nn.MaxPool2d(
216
+ kernel_size=kernel_size, stride=stride, padding=padding
217
+ )
218
+ self.downsample = nn.Sequential(
219
+ nn.Conv2d(
220
+ in_channels=in_channels,
221
+ out_channels=out_channels,
222
+ kernel_size=kernel_size,
223
+ stride=stride,
224
+ padding=padding,
225
+ ),
226
+ nn.BatchNorm2d(out_channels),
227
+ )
228
+ self.use_downsampling = use_downsampling
229
+ self.pooling = pooling
230
+ self.dropout = nn.Dropout(drop_prob)
231
+
232
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
233
+ """
234
+
235
+ Parameters
236
+ ----------
237
+ X: Tensor
238
+ Input tensor of shape (batch_size, n_channels, n_freqs, n_times).
239
+
240
+ Returns
241
+ -------
242
+ Tensor
243
+ Output tensor of shape (batch_size, n_channels, n_freqs, n_times).
244
+ """
245
+ out = self.conv1(x)
246
+ out = self.bn1(out)
247
+ out = self.relu(out)
248
+ out = self.conv2(out)
249
+ out = self.bn2(out)
250
+ if self.use_downsampling:
251
+ residual = self.downsample(x)
252
+ out += residual
253
+ if self.pooling:
254
+ out = self.maxpool(out)
255
+ out = self.dropout(out)
256
+ return out
257
+
258
+
259
+ class _STFTModule(nn.Module):
260
+ """
261
+ A PyTorch module that computes the Short-Time Fourier Transform (STFT)
262
+ of an EEG batch tensor.
263
+
264
+ Expects input of shape (batch_size, n_channels, n_times) and returns
265
+ (batch_size, n_channels, n_freqs, n_times).
266
+ """
267
+
268
+ def __init__(
269
+ self,
270
+ n_fft: int,
271
+ hop_length: int,
272
+ center: bool = True,
273
+ onesided: bool = True,
274
+ return_complex: bool = True,
275
+ normalized: bool = True,
276
+ ):
277
+ """
278
+ Parameters
279
+ ----------
280
+ n_fft : int
281
+ Number of FFT points (window size).
282
+ steps : int
283
+ Number of hops per window (i.e. hop_length = n_fft // steps).
284
+ """
285
+ super().__init__()
286
+ self.n_fft = n_fft
287
+ self.hop_length = hop_length
288
+ self.center = center
289
+ self.one_sided = onesided
290
+ self.return_complex = return_complex
291
+ self.normalized = normalized
292
+
293
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
294
+ window = torch.ones(self.n_fft, device=x.device)
295
+
296
+ # x: (B, C, T)
297
+ B, C, T = x.shape
298
+ # flatten batch & channel into one dim
299
+ x_flat = x.reshape(B * C, T)
300
+
301
+ # compute stft on 2D tensor
302
+ spec_flat = torch.stft(
303
+ x_flat,
304
+ n_fft=self.n_fft,
305
+ hop_length=self.hop_length,
306
+ win_length=self.n_fft,
307
+ window=window,
308
+ normalized=self.normalized,
309
+ center=self.center,
310
+ onesided=self.one_sided,
311
+ return_complex=self.return_complex,
312
+ )
313
+
314
+ F, L = spec_flat.shape[-2], spec_flat.shape[-1]
315
+ spec = spec_flat.view(B, C, F, L)
316
+
317
+ return torch.abs(spec)