braindecode 0.8.1__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of braindecode might be problematic. Click here for more details.

Files changed (108)
  1. braindecode/__init__.py +1 -2
  2. braindecode/augmentation/__init__.py +39 -19
  3. braindecode/augmentation/base.py +25 -28
  4. braindecode/augmentation/functional.py +237 -100
  5. braindecode/augmentation/transforms.py +326 -158
  6. braindecode/classifier.py +26 -24
  7. braindecode/datasets/__init__.py +28 -10
  8. braindecode/datasets/base.py +220 -134
  9. braindecode/datasets/bbci.py +43 -52
  10. braindecode/datasets/bcicomp.py +47 -32
  11. braindecode/datasets/bids.py +245 -0
  12. braindecode/datasets/mne.py +45 -24
  13. braindecode/datasets/moabb.py +87 -27
  14. braindecode/datasets/nmt.py +311 -0
  15. braindecode/datasets/sleep_physio_challe_18.py +412 -0
  16. braindecode/datasets/sleep_physionet.py +43 -26
  17. braindecode/datasets/tuh.py +324 -140
  18. braindecode/datasets/xy.py +27 -12
  19. braindecode/datautil/__init__.py +34 -18
  20. braindecode/datautil/serialization.py +98 -71
  21. braindecode/eegneuralnet.py +63 -47
  22. braindecode/functional/__init__.py +10 -0
  23. braindecode/functional/functions.py +251 -0
  24. braindecode/functional/initialization.py +47 -0
  25. braindecode/models/__init__.py +36 -14
  26. braindecode/models/atcnet.py +153 -159
  27. braindecode/models/attentionbasenet.py +550 -0
  28. braindecode/models/base.py +86 -102
  29. braindecode/models/biot.py +483 -0
  30. braindecode/models/contrawr.py +296 -0
  31. braindecode/models/ctnet.py +450 -0
  32. braindecode/models/deep4.py +64 -75
  33. braindecode/models/deepsleepnet.py +149 -119
  34. braindecode/models/eegconformer.py +111 -171
  35. braindecode/models/eeginception_erp.py +109 -118
  36. braindecode/models/eeginception_mi.py +155 -97
  37. braindecode/models/eegitnet.py +215 -151
  38. braindecode/models/eegminer.py +255 -0
  39. braindecode/models/eegnet.py +229 -161
  40. braindecode/models/eegnex.py +247 -0
  41. braindecode/models/eegresnet.py +234 -152
  42. braindecode/models/eegsimpleconv.py +199 -0
  43. braindecode/models/eegtcnet.py +335 -0
  44. braindecode/models/fbcnet.py +221 -0
  45. braindecode/models/fblightconvnet.py +313 -0
  46. braindecode/models/fbmsnet.py +325 -0
  47. braindecode/models/hybrid.py +52 -71
  48. braindecode/models/ifnet.py +441 -0
  49. braindecode/models/labram.py +1166 -0
  50. braindecode/models/msvtnet.py +375 -0
  51. braindecode/models/sccnet.py +182 -0
  52. braindecode/models/shallow_fbcsp.py +50 -56
  53. braindecode/models/signal_jepa.py +1012 -0
  54. braindecode/models/sinc_shallow.py +337 -0
  55. braindecode/models/sleep_stager_blanco_2020.py +55 -46
  56. braindecode/models/sleep_stager_chambon_2018.py +54 -53
  57. braindecode/models/sleep_stager_eldele_2021.py +248 -141
  58. braindecode/models/sparcnet.py +378 -0
  59. braindecode/models/summary.csv +41 -0
  60. braindecode/models/syncnet.py +232 -0
  61. braindecode/models/tcn.py +158 -88
  62. braindecode/models/tidnet.py +280 -167
  63. braindecode/models/tsinception.py +258 -0
  64. braindecode/models/usleep.py +190 -177
  65. braindecode/models/util.py +109 -141
  66. braindecode/modules/__init__.py +38 -0
  67. braindecode/modules/activation.py +60 -0
  68. braindecode/modules/attention.py +757 -0
  69. braindecode/modules/blocks.py +108 -0
  70. braindecode/modules/convolution.py +274 -0
  71. braindecode/modules/filter.py +632 -0
  72. braindecode/modules/layers.py +133 -0
  73. braindecode/modules/linear.py +50 -0
  74. braindecode/modules/parametrization.py +38 -0
  75. braindecode/modules/stats.py +77 -0
  76. braindecode/modules/util.py +77 -0
  77. braindecode/modules/wrapper.py +75 -0
  78. braindecode/preprocessing/__init__.py +36 -11
  79. braindecode/preprocessing/mne_preprocess.py +13 -7
  80. braindecode/preprocessing/preprocess.py +139 -75
  81. braindecode/preprocessing/windowers.py +576 -187
  82. braindecode/regressor.py +23 -12
  83. braindecode/samplers/__init__.py +16 -8
  84. braindecode/samplers/base.py +148 -32
  85. braindecode/samplers/ssl.py +162 -17
  86. braindecode/training/__init__.py +18 -10
  87. braindecode/training/callbacks.py +2 -4
  88. braindecode/training/losses.py +3 -8
  89. braindecode/training/scoring.py +76 -68
  90. braindecode/util.py +55 -59
  91. braindecode/version.py +1 -1
  92. braindecode/visualization/__init__.py +2 -3
  93. braindecode/visualization/confusion_matrices.py +117 -73
  94. braindecode/visualization/gradients.py +14 -10
  95. {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info}/METADATA +39 -55
  96. braindecode-1.0.0.dist-info/RECORD +101 -0
  97. {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info}/WHEEL +1 -1
  98. {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info/licenses}/LICENSE.txt +1 -1
  99. braindecode-1.0.0.dist-info/licenses/NOTICE.txt +20 -0
  100. braindecode/datautil/mne.py +0 -9
  101. braindecode/datautil/preprocess.py +0 -12
  102. braindecode/datautil/windowers.py +0 -6
  103. braindecode/datautil/xy.py +0 -9
  104. braindecode/models/eeginception.py +0 -317
  105. braindecode/models/functions.py +0 -47
  106. braindecode/models/modules.py +0 -358
  107. braindecode-0.8.1.dist-info/RECORD +0 -68
  108. {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,296 @@
1
+ from __future__ import annotations
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+ from braindecode.models.base import EEGModuleMixin
7
+
8
+
9
class ContraWR(EEGModuleMixin, nn.Module):
    """Contrast with the World Representation ContraWR from Yang et al (2021) [Yang2021]_.

    This model is a convolutional neural network that operates on a spectral
    (STFT magnitude) representation of the input and processes it with a
    series of 2D convolutional residual blocks. The model is designed to
    learn a representation of the EEG signal that can be used for sleep
    staging.

    Parameters
    ----------
    steps : int, optional
        Number of STFT hops per FFT window: the STFT ``hop_length`` is
        ``int(sfreq) // steps``. Must be small enough that this hop is at
        least 1. By default 20.
    emb_size : int, optional
        Embedding size for the final layer, by default 256.
    res_channels : list[int], optional
        Number of channels for each residual block, by default [32, 64, 128].
        A tuple of integers is accepted as well.
    activation: nn.Module, default=nn.ELU
        Activation function class applied in the classification head, right
        before the final linear layer. Should be a PyTorch activation module
        class like ``nn.ReLU`` or ``nn.ELU``. Default is ``nn.ELU``. The
        residual blocks are built without this argument and therefore use
        ``nn.ReLU``.
    drop_prob : float, default=0.5
        The dropout rate for regularization. Values should be between 0 and 1.

    .. versionadded:: 0.9

    Notes
    -----
    This implementation is not guaranteed to be correct and has not been
    checked by the original authors. The modifications are minimal and the
    model is expected to work as intended. The original code is from
    [Code2023]_.

    The feature map produced by the last residual block is averaged over its
    frequency and time axes to obtain the embedding. When that map is already
    1 x 1 this is identical to squeezing those axes, as the original code
    does. For longer inputs it avoids a shape mismatch in the final linear
    layer.

    References
    ----------
    .. [Yang2021] Yang, C., Xiao, C., Westover, M. B., & Sun, J. (2023).
       Self-supervised electroencephalogram representation learning for automatic
       sleep staging: model development and evaluation study. JMIR AI, 2(1), e46769.
    .. [Code2023] Yang, C., Westover, M.B. and Sun, J., 2023. BIOT
       Biosignal Transformer for Cross-data Learning in the Wild.
       GitHub https://github.com/ycq091044/BIOT (accessed 2024-02-13)
    """

    def __init__(
        self,
        n_chans=None,
        n_outputs=None,
        sfreq=None,
        emb_size: int = 256,
        res_channels: list[int] = [32, 64, 128],
        steps=20,
        activation: nn.Module = nn.ELU,
        drop_prob: float = 0.5,
        # Another way to pass the EEG parameters
        chs_info=None,
        n_times=None,
        input_window_seconds=None,
    ):
        super().__init__(
            n_outputs=n_outputs,
            n_chans=n_chans,
            chs_info=chs_info,
            n_times=n_times,
            input_window_seconds=input_window_seconds,
            sfreq=sfreq,
        )
        del n_outputs, n_chans, chs_info, n_times, sfreq, input_window_seconds
        # The default list is only read (a new list is built below), never
        # mutated, so sharing it across instances is harmless.
        if not isinstance(res_channels, (list, tuple)):
            raise ValueError("res_channels must be a list or tuple of integers.")

        self.n_fft = int(self.sfreq)
        self.steps = steps

        hop_length = int(self.n_fft // self.steps)
        if hop_length < 1:
            # torch.stft would otherwise fail only at the first forward call.
            raise ValueError(
                f"int(sfreq) // steps must be >= 1 to give a positive STFT "
                f"hop_length, got int(sfreq)={self.n_fft} and steps={steps}. "
                f"Reduce `steps` or increase `sfreq`."
            )

        # Channel progression: EEG channels -> residual channels -> embedding.
        res_channels = [self.n_chans, *res_channels, emb_size]

        self.torch_stft = _STFTModule(
            n_fft=self.n_fft,
            hop_length=hop_length,
        )

        self.convs = nn.ModuleList(
            [
                _ResBlock(
                    in_channels=in_ch,
                    out_channels=out_ch,
                    stride=2,
                    use_downsampling=True,
                    pooling=True,
                    drop_prob=drop_prob,
                )
                for in_ch, out_ch in zip(res_channels[:-1], res_channels[1:])
            ]
        )

        self.final_layer = nn.Sequential(
            activation(),
            nn.Linear(emb_size, self.n_outputs),
        )

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        """
        Forward pass.

        Parameters
        ----------
        X: Tensor
            Input tensor of shape (batch_size, n_channels, n_times).

        Returns
        -------
        Tensor
            Output tensor of shape (batch_size, n_outputs).
        """
        # (batch, n_chans, n_times) -> (batch, n_chans, n_freqs, n_frames)
        X = self.torch_stft(X)

        # Call the modules (not .forward) so that registered hooks run.
        for conv in self.convs:
            X = conv(X)

        # Global average over the frequency and time axes gives
        # (batch, emb_size). This equals squeeze(-1).squeeze(-1) when the map
        # is 1 x 1, and stays valid when it is larger.
        emb = X.mean(dim=(-2, -1))
        return self.final_layer(emb)
125
+
126
+
127
+ class _ResBlock(nn.Module):
128
+ """Convolutional Residual Block 2D.
129
+
130
+ This block stacks two convolutional layers with batch normalization,
131
+ max pooling, dropout, and residual connection.
132
+
133
+ Parameters
134
+ ----------
135
+ in_channels : int
136
+ Number of input channels.
137
+ out_channels : int
138
+ Number of output channels.
139
+ stride : int (default=1)
140
+ Stride of the convolutional layers.
141
+ use_downsampling : bool (default=True)
142
+ Whether to use a downsampling residual connection.
143
+ pooling : bool (default=True)
144
+ Whether to use max pooling.
145
+ kernel_size : int (default=3)
146
+ Kernel size of the convolutional layers.
147
+ padding : int (default=1)
148
+ Padding of the convolutional layers.
149
+ activation: nn.Module, default=nn.ELU
150
+ Activation function class to apply. Should be a PyTorch activation
151
+ module class like ``nn.ReLU`` or ``nn.ELU``. Default is ``nn.ReLU``.
152
+ drop_prob : float, default=0.5
153
+ The dropout rate for regularization. Values should be between 0 and 1.
154
+
155
+ Examples
156
+ --------
157
+ >>> import torch
158
+ >>> model = ResBlock2D(6, 16, 1, True, True)
159
+ >>> input_ = torch.randn((16, 6, 28, 150)) # (batch, channel, height, width)
160
+ >>> output = model(input_)
161
+ >>> output.shape
162
+ torch.Size([16, 16, 14, 75])
163
+ """
164
+
165
+ def __init__(
166
+ self,
167
+ in_channels,
168
+ out_channels,
169
+ stride=1,
170
+ use_downsampling=True,
171
+ pooling=True,
172
+ kernel_size=3,
173
+ padding=1,
174
+ drop_prob=0.5,
175
+ activation: nn.Module = nn.ReLU,
176
+ ):
177
+ super().__init__()
178
+ self.conv1 = nn.Conv2d(
179
+ in_channels=in_channels,
180
+ out_channels=out_channels,
181
+ kernel_size=kernel_size,
182
+ stride=stride,
183
+ padding=padding,
184
+ )
185
+ self.bn1 = nn.BatchNorm2d(out_channels)
186
+ self.relu = activation()
187
+ self.conv2 = nn.Conv2d(
188
+ in_channels=out_channels,
189
+ out_channels=out_channels,
190
+ kernel_size=kernel_size,
191
+ padding=padding,
192
+ )
193
+ self.bn2 = nn.BatchNorm2d(out_channels)
194
+ self.maxpool = nn.MaxPool2d(
195
+ kernel_size=kernel_size, stride=stride, padding=padding
196
+ )
197
+ self.downsample = nn.Sequential(
198
+ nn.Conv2d(
199
+ in_channels=in_channels,
200
+ out_channels=out_channels,
201
+ kernel_size=kernel_size,
202
+ stride=stride,
203
+ padding=padding,
204
+ ),
205
+ nn.BatchNorm2d(out_channels),
206
+ )
207
+ self.use_downsampling = use_downsampling
208
+ self.pooling = pooling
209
+ self.dropout = nn.Dropout(drop_prob)
210
+
211
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
212
+ """
213
+
214
+ Parameters
215
+ ----------
216
+ X: Tensor
217
+ Input tensor of shape (batch_size, n_channels, n_freqs, n_times).
218
+
219
+ Returns
220
+ -------
221
+ Tensor
222
+ Output tensor of shape (batch_size, n_channels, n_freqs, n_times).
223
+ """
224
+ out = self.conv1(x)
225
+ out = self.bn1(out)
226
+ out = self.relu(out)
227
+ out = self.conv2(out)
228
+ out = self.bn2(out)
229
+ if self.use_downsampling:
230
+ residual = self.downsample(x)
231
+ out += residual
232
+ if self.pooling:
233
+ out = self.maxpool(out)
234
+ out = self.dropout(out)
235
+ return out
236
+
237
+
238
+ class _STFTModule(nn.Module):
239
+ """
240
+ A PyTorch module that computes the Short-Time Fourier Transform (STFT)
241
+ of an EEG batch tensor.
242
+
243
+ Expects input of shape (batch_size, n_channels, n_times) and returns
244
+ (batch_size, n_channels, n_freqs, n_times).
245
+ """
246
+
247
+ def __init__(
248
+ self,
249
+ n_fft: int,
250
+ hop_length: int,
251
+ center: bool = True,
252
+ onesided: bool = True,
253
+ return_complex: bool = True,
254
+ normalized: bool = True,
255
+ ):
256
+ """
257
+ Parameters
258
+ ----------
259
+ n_fft : int
260
+ Number of FFT points (window size).
261
+ steps : int
262
+ Number of hops per window (i.e. hop_length = n_fft // steps).
263
+ """
264
+ super().__init__()
265
+ self.n_fft = n_fft
266
+ self.hop_length = hop_length
267
+ self.center = center
268
+ self.one_sided = onesided
269
+ self.return_complex = return_complex
270
+ self.normalized = normalized
271
+
272
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
273
+ window = torch.ones(self.n_fft, device=x.device)
274
+
275
+ # x: (B, C, T)
276
+ B, C, T = x.shape
277
+ # flatten batch & channel into one dim
278
+ x_flat = x.reshape(B * C, T)
279
+
280
+ # compute stft on 2D tensor
281
+ spec_flat = torch.stft(
282
+ x_flat,
283
+ n_fft=self.n_fft,
284
+ hop_length=self.hop_length,
285
+ win_length=self.n_fft,
286
+ window=window,
287
+ normalized=self.normalized,
288
+ center=self.center,
289
+ onesided=self.one_sided,
290
+ return_complex=self.return_complex,
291
+ )
292
+
293
+ F, L = spec_flat.shape[-2], spec_flat.shape[-1]
294
+ spec = spec_flat.view(B, C, F, L)
295
+
296
+ return torch.abs(spec)