braindecode-0.8.1-py3-none-any.whl → braindecode-1.0.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of braindecode might be problematic; the source page links to more details.

Files changed (108)
  1. braindecode/__init__.py +1 -2
  2. braindecode/augmentation/__init__.py +39 -19
  3. braindecode/augmentation/base.py +25 -28
  4. braindecode/augmentation/functional.py +237 -100
  5. braindecode/augmentation/transforms.py +326 -158
  6. braindecode/classifier.py +26 -24
  7. braindecode/datasets/__init__.py +28 -10
  8. braindecode/datasets/base.py +220 -134
  9. braindecode/datasets/bbci.py +43 -52
  10. braindecode/datasets/bcicomp.py +47 -32
  11. braindecode/datasets/bids.py +245 -0
  12. braindecode/datasets/mne.py +45 -24
  13. braindecode/datasets/moabb.py +87 -27
  14. braindecode/datasets/nmt.py +311 -0
  15. braindecode/datasets/sleep_physio_challe_18.py +412 -0
  16. braindecode/datasets/sleep_physionet.py +43 -26
  17. braindecode/datasets/tuh.py +324 -140
  18. braindecode/datasets/xy.py +27 -12
  19. braindecode/datautil/__init__.py +34 -18
  20. braindecode/datautil/serialization.py +98 -71
  21. braindecode/eegneuralnet.py +63 -47
  22. braindecode/functional/__init__.py +10 -0
  23. braindecode/functional/functions.py +251 -0
  24. braindecode/functional/initialization.py +47 -0
  25. braindecode/models/__init__.py +36 -14
  26. braindecode/models/atcnet.py +153 -159
  27. braindecode/models/attentionbasenet.py +550 -0
  28. braindecode/models/base.py +86 -102
  29. braindecode/models/biot.py +483 -0
  30. braindecode/models/contrawr.py +296 -0
  31. braindecode/models/ctnet.py +450 -0
  32. braindecode/models/deep4.py +64 -75
  33. braindecode/models/deepsleepnet.py +149 -119
  34. braindecode/models/eegconformer.py +111 -171
  35. braindecode/models/eeginception_erp.py +109 -118
  36. braindecode/models/eeginception_mi.py +155 -97
  37. braindecode/models/eegitnet.py +215 -151
  38. braindecode/models/eegminer.py +255 -0
  39. braindecode/models/eegnet.py +229 -161
  40. braindecode/models/eegnex.py +247 -0
  41. braindecode/models/eegresnet.py +234 -152
  42. braindecode/models/eegsimpleconv.py +199 -0
  43. braindecode/models/eegtcnet.py +335 -0
  44. braindecode/models/fbcnet.py +221 -0
  45. braindecode/models/fblightconvnet.py +313 -0
  46. braindecode/models/fbmsnet.py +325 -0
  47. braindecode/models/hybrid.py +52 -71
  48. braindecode/models/ifnet.py +441 -0
  49. braindecode/models/labram.py +1166 -0
  50. braindecode/models/msvtnet.py +375 -0
  51. braindecode/models/sccnet.py +182 -0
  52. braindecode/models/shallow_fbcsp.py +50 -56
  53. braindecode/models/signal_jepa.py +1012 -0
  54. braindecode/models/sinc_shallow.py +337 -0
  55. braindecode/models/sleep_stager_blanco_2020.py +55 -46
  56. braindecode/models/sleep_stager_chambon_2018.py +54 -53
  57. braindecode/models/sleep_stager_eldele_2021.py +248 -141
  58. braindecode/models/sparcnet.py +378 -0
  59. braindecode/models/summary.csv +41 -0
  60. braindecode/models/syncnet.py +232 -0
  61. braindecode/models/tcn.py +158 -88
  62. braindecode/models/tidnet.py +280 -167
  63. braindecode/models/tsinception.py +258 -0
  64. braindecode/models/usleep.py +190 -177
  65. braindecode/models/util.py +109 -141
  66. braindecode/modules/__init__.py +38 -0
  67. braindecode/modules/activation.py +60 -0
  68. braindecode/modules/attention.py +757 -0
  69. braindecode/modules/blocks.py +108 -0
  70. braindecode/modules/convolution.py +274 -0
  71. braindecode/modules/filter.py +632 -0
  72. braindecode/modules/layers.py +133 -0
  73. braindecode/modules/linear.py +50 -0
  74. braindecode/modules/parametrization.py +38 -0
  75. braindecode/modules/stats.py +77 -0
  76. braindecode/modules/util.py +77 -0
  77. braindecode/modules/wrapper.py +75 -0
  78. braindecode/preprocessing/__init__.py +36 -11
  79. braindecode/preprocessing/mne_preprocess.py +13 -7
  80. braindecode/preprocessing/preprocess.py +139 -75
  81. braindecode/preprocessing/windowers.py +576 -187
  82. braindecode/regressor.py +23 -12
  83. braindecode/samplers/__init__.py +16 -8
  84. braindecode/samplers/base.py +148 -32
  85. braindecode/samplers/ssl.py +162 -17
  86. braindecode/training/__init__.py +18 -10
  87. braindecode/training/callbacks.py +2 -4
  88. braindecode/training/losses.py +3 -8
  89. braindecode/training/scoring.py +76 -68
  90. braindecode/util.py +55 -59
  91. braindecode/version.py +1 -1
  92. braindecode/visualization/__init__.py +2 -3
  93. braindecode/visualization/confusion_matrices.py +117 -73
  94. braindecode/visualization/gradients.py +14 -10
  95. {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info}/METADATA +39 -55
  96. braindecode-1.0.0.dist-info/RECORD +101 -0
  97. {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info}/WHEEL +1 -1
  98. {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info/licenses}/LICENSE.txt +1 -1
  99. braindecode-1.0.0.dist-info/licenses/NOTICE.txt +20 -0
  100. braindecode/datautil/mne.py +0 -9
  101. braindecode/datautil/preprocess.py +0 -12
  102. braindecode/datautil/windowers.py +0 -6
  103. braindecode/datautil/xy.py +0 -9
  104. braindecode/models/eeginception.py +0 -317
  105. braindecode/models/functions.py +0 -47
  106. braindecode/models/modules.py +0 -358
  107. braindecode-0.8.1.dist-info/RECORD +0 -68
  108. {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info}/top_level.txt +0 -0
braindecode/models/deepsleepnet.py +149 -119
@@ -4,11 +4,133 @@
4
4
  import torch
5
5
  import torch.nn as nn
6
6
 
7
- from .base import EEGModuleMixin, deprecated_args
7
+ from braindecode.models.base import EEGModuleMixin
8
8
 
9
9
 
10
- class _SmallCNN(nn.Module): # smaller filter sizes to learn temporal information
11
- def __init__(self):
10
+ class DeepSleepNet(EEGModuleMixin, nn.Module):
11
+ """Sleep staging architecture from Supratak et al. (2017) [Supratak2017]_.
12
+
13
+ .. figure:: https://raw.githubusercontent.com/akaraspt/deepsleepnet/refs/heads/master/img/deepsleepnet.png
14
+ :align: center
15
+ :alt: DeepSleepNet Architecture
16
+
17
+ Convolutional neural network and bidirectional-Long Short-Term
18
+ for single channels sleep staging described in [Supratak2017]_.
19
+
20
+ Parameters
21
+ ----------
22
+ activation_large: nn.Module, default=nn.ELU
23
+ Activation function class to apply. Should be a PyTorch activation
24
+ module class like ``nn.ReLU`` or ``nn.ELU``. Default is ``nn.ELU``.
25
+ activation_small: nn.Module, default=nn.ReLU
26
+ Activation function class to apply. Should be a PyTorch activation
27
+ module class like ``nn.ReLU`` or ``nn.ELU``. Default is ``nn.ReLU``.
28
+ return_feats : bool
29
+ If True, return the features, i.e. the output of the feature extractor
30
+ (before the final linear layer). If False, pass the features through
31
+ the final linear layer.
32
+ drop_prob : float, default=0.5
33
+ The dropout rate for regularization. Values should be between 0 and 1.
34
+
35
+
36
+ References
37
+ ----------
38
+ .. [Supratak2017] Supratak, A., Dong, H., Wu, C., & Guo, Y. (2017).
39
+ DeepSleepNet: A model for automatic sleep stage scoring based
40
+ on raw single-channel EEG. IEEE Transactions on Neural Systems
41
+ and Rehabilitation Engineering, 25(11), 1998-2008.
42
+ """
43
+
44
+ def __init__(
45
+ self,
46
+ n_outputs=5,
47
+ return_feats=False,
48
+ n_chans=None,
49
+ chs_info=None,
50
+ n_times=None,
51
+ input_window_seconds=None,
52
+ sfreq=None,
53
+ activation_large: nn.Module = nn.ELU,
54
+ activation_small: nn.Module = nn.ReLU,
55
+ drop_prob: float = 0.5,
56
+ ):
57
+ super().__init__(
58
+ n_outputs=n_outputs,
59
+ n_chans=n_chans,
60
+ chs_info=chs_info,
61
+ n_times=n_times,
62
+ input_window_seconds=input_window_seconds,
63
+ sfreq=sfreq,
64
+ )
65
+ del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq
66
+ self.cnn1 = _SmallCNN(activation=activation_small, drop_prob=drop_prob)
67
+ self.cnn2 = _LargeCNN(activation=activation_large, drop_prob=drop_prob)
68
+ self.dropout = nn.Dropout(0.5)
69
+ self.bilstm = _BiLSTM(input_size=3072, hidden_size=512, num_layers=2)
70
+ self.fc = nn.Sequential(
71
+ nn.Linear(3072, 1024, bias=False), nn.BatchNorm1d(num_features=1024)
72
+ )
73
+
74
+ self.features_extractor = nn.Identity()
75
+ self.len_last_layer = 1024
76
+ self.return_feats = return_feats
77
+
78
+ # TODO: Add new way to handle return_features == True
79
+ if not return_feats:
80
+ self.final_layer = nn.Linear(1024, self.n_outputs)
81
+ else:
82
+ self.final_layer = nn.Identity()
83
+
84
+ def forward(self, x):
85
+ """Forward pass.
86
+
87
+ Parameters
88
+ ----------
89
+ x: torch.Tensor
90
+ Batch of EEG windows of shape (batch_size, n_channels, n_times).
91
+ """
92
+
93
+ if x.ndim == 3:
94
+ x = x.unsqueeze(1)
95
+
96
+ x1 = self.cnn1(x)
97
+ x1 = x1.flatten(start_dim=1)
98
+
99
+ x2 = self.cnn2(x)
100
+ x2 = x2.flatten(start_dim=1)
101
+
102
+ x = torch.cat((x1, x2), dim=1)
103
+ x = self.dropout(x)
104
+ temp = x.clone()
105
+ temp = self.fc(temp)
106
+ x = x.unsqueeze(1)
107
+ x = self.bilstm(x)
108
+ x = x.squeeze()
109
+ x = torch.add(x, temp)
110
+ x = self.dropout(x)
111
+
112
+ feats = self.features_extractor(x)
113
+
114
+ if self.return_feats:
115
+ return feats
116
+ else:
117
+ return self.final_layer(feats)
118
+
119
+
120
+ class _SmallCNN(nn.Module):
121
+ """
122
+ Smaller filter sizes to learn temporal information.
123
+
124
+ Parameters
125
+ ----------
126
+ activation: nn.Module, default=nn.ReLU
127
+ Activation function class to apply. Should be a PyTorch activation
128
+ module class like ``nn.ReLU`` or ``nn.ELU``. Default is ``nn.ReLU``.
129
+ drop_prob : float, default=0.5
130
+ The dropout rate for regularization. Values should be between 0 and 1.
131
+ """
132
+
133
+ def __init__(self, activation: nn.Module = nn.ReLU, drop_prob: float = 0.5):
12
134
  super().__init__()
13
135
  self.conv1 = nn.Sequential(
14
136
  nn.Conv2d(
@@ -20,10 +142,10 @@ class _SmallCNN(nn.Module): # smaller filter sizes to learn temporal informatio
20
142
  bias=False,
21
143
  ),
22
144
  nn.BatchNorm2d(num_features=64),
23
- nn.ReLU(),
145
+ activation(),
24
146
  )
25
147
  self.pool1 = nn.MaxPool2d(kernel_size=(1, 8), stride=(1, 8), padding=(0, 2))
26
- self.dropout = nn.Dropout(p=0.5)
148
+ self.dropout = nn.Dropout(p=drop_prob)
27
149
  self.conv2 = nn.Sequential(
28
150
  nn.Conv2d(
29
151
  in_channels=64,
@@ -34,7 +156,7 @@ class _SmallCNN(nn.Module): # smaller filter sizes to learn temporal informatio
34
156
  bias=False,
35
157
  ),
36
158
  nn.BatchNorm2d(num_features=128),
37
- nn.ReLU(),
159
+ activation(),
38
160
  )
39
161
  self.conv3 = nn.Sequential(
40
162
  nn.Conv2d(
@@ -46,7 +168,7 @@ class _SmallCNN(nn.Module): # smaller filter sizes to learn temporal informatio
46
168
  bias=False,
47
169
  ),
48
170
  nn.BatchNorm2d(num_features=128),
49
- nn.ReLU(),
171
+ activation(),
50
172
  )
51
173
  self.conv4 = nn.Sequential(
52
174
  nn.Conv2d(
@@ -58,7 +180,7 @@ class _SmallCNN(nn.Module): # smaller filter sizes to learn temporal informatio
58
180
  bias=False,
59
181
  ),
60
182
  nn.BatchNorm2d(num_features=128),
61
- nn.ReLU(),
183
+ activation(),
62
184
  )
63
185
  self.pool2 = nn.MaxPool2d(kernel_size=(1, 4), stride=(1, 4), padding=(0, 1))
64
186
 
@@ -72,8 +194,19 @@ class _SmallCNN(nn.Module): # smaller filter sizes to learn temporal informatio
72
194
  return x
73
195
 
74
196
 
75
- class _LargeCNN(nn.Module): # larger filter sizes to learn frequency information
76
- def __init__(self):
197
+ class _LargeCNN(nn.Module):
198
+ """
199
+ Larger filter sizes to learn frequency information.
200
+
201
+ Parameters
202
+ ----------
203
+ activation: nn.Module, default=nn.ELU
204
+ Activation function class to apply. Should be a PyTorch activation
205
+ module class like ``nn.ReLU`` or ``nn.ELU``. Default is ``nn.ELU``.
206
+
207
+ """
208
+
209
+ def __init__(self, activation: nn.Module = nn.ELU, drop_prob: float = 0.5):
77
210
  super().__init__()
78
211
 
79
212
  self.conv1 = nn.Sequential(
@@ -86,10 +219,10 @@ class _LargeCNN(nn.Module): # larger filter sizes to learn frequency informatio
86
219
  bias=False,
87
220
  ),
88
221
  nn.BatchNorm2d(num_features=64),
89
- nn.ReLU(),
222
+ activation(),
90
223
  )
91
224
  self.pool1 = nn.MaxPool2d(kernel_size=(1, 4), stride=(1, 4))
92
- self.dropout = nn.Dropout(p=0.5)
225
+ self.dropout = nn.Dropout(p=drop_prob)
93
226
  self.conv2 = nn.Sequential(
94
227
  nn.Conv2d(
95
228
  in_channels=64,
@@ -100,7 +233,7 @@ class _LargeCNN(nn.Module): # larger filter sizes to learn frequency informatio
100
233
  bias=False,
101
234
  ),
102
235
  nn.BatchNorm2d(num_features=128),
103
- nn.ReLU(),
236
+ activation(),
104
237
  )
105
238
  self.conv3 = nn.Sequential(
106
239
  nn.Conv2d(
@@ -112,7 +245,7 @@ class _LargeCNN(nn.Module): # larger filter sizes to learn frequency informatio
112
245
  bias=False,
113
246
  ),
114
247
  nn.BatchNorm2d(num_features=128),
115
- nn.ReLU(),
248
+ activation(),
116
249
  )
117
250
  self.conv4 = nn.Sequential(
118
251
  nn.Conv2d(
@@ -124,7 +257,7 @@ class _LargeCNN(nn.Module): # larger filter sizes to learn frequency informatio
124
257
  bias=False,
125
258
  ),
126
259
  nn.BatchNorm2d(num_features=128),
127
- nn.ReLU(),
260
+ activation(),
128
261
  )
129
262
  self.pool2 = nn.MaxPool2d(kernel_size=(1, 2), stride=(1, 2), padding=(0, 1))
130
263
 
@@ -154,112 +287,9 @@ class _BiLSTM(nn.Module):
154
287
 
155
288
  def forward(self, x):
156
289
  # set initial hidden and cell states
157
- h0 = torch.zeros(
158
- self.num_layers * 2, x.size(0), self.hidden_size
159
- ).to(x.device)
290
+ h0 = torch.zeros(self.num_layers * 2, x.size(0), self.hidden_size).to(x.device)
160
291
  c0 = torch.zeros(self.num_layers * 2, x.size(0), self.hidden_size).to(x.device)
161
292
 
162
293
  # forward propagate LSTM
163
294
  out, _ = self.lstm(x, (h0, c0))
164
295
  return out
165
-
166
-
167
- class DeepSleepNet(EEGModuleMixin, nn.Module):
168
- """Sleep staging architecture from Supratak et al 2017.
169
-
170
- Convolutional neural network and bidirectional-Long Short-Term
171
- for single channels sleep staging described in [Supratak2017]_.
172
-
173
- Parameters
174
- ----------
175
- return_feats : bool
176
- If True, return the features, i.e. the output of the feature extractor
177
- (before the final linear layer). If False, pass the features through
178
- the final linear layer.
179
- n_classes :
180
- Alias for n_outputs.
181
-
182
- References
183
- ----------
184
- .. [Supratak2017] Supratak, A., Dong, H., Wu, C., & Guo, Y. (2017).
185
- DeepSleepNet: A model for automatic sleep stage scoring based
186
- on raw single-channel EEG. IEEE Transactions on Neural Systems
187
- and Rehabilitation Engineering, 25(11), 1998-2008.
188
- """
189
-
190
- def __init__(
191
- self,
192
- n_outputs=5,
193
- return_feats=False,
194
- n_chans=None,
195
- chs_info=None,
196
- n_times=None,
197
- input_window_seconds=None,
198
- sfreq=None,
199
- n_classes=None,
200
- ):
201
- n_outputs, = deprecated_args(
202
- self,
203
- ('n_classes', 'n_outputs', n_classes, n_outputs),
204
- )
205
- super().__init__(
206
- n_outputs=n_outputs,
207
- n_chans=n_chans,
208
- chs_info=chs_info,
209
- n_times=n_times,
210
- input_window_seconds=input_window_seconds,
211
- sfreq=sfreq,
212
- )
213
- del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq
214
- del n_classes
215
- self.cnn1 = _SmallCNN()
216
- self.cnn2 = _LargeCNN()
217
- self.dropout = nn.Dropout(0.5)
218
- self.bilstm = _BiLSTM(input_size=3072, hidden_size=512, num_layers=2)
219
- self.fc = nn.Sequential(nn.Linear(3072, 1024, bias=False),
220
- nn.BatchNorm1d(num_features=1024))
221
-
222
- self.features_extractor = nn.Identity()
223
- self.len_last_layer = 1024
224
- self.return_feats = return_feats
225
-
226
- # TODO: Add new way to handle return_features == True
227
- if not return_feats:
228
- self.final_layer = nn.Linear(1024, self.n_outputs)
229
- else:
230
- self.final_layer = nn.Identity()
231
-
232
- def forward(self, x):
233
- """Forward pass.
234
-
235
- Parameters
236
- ----------
237
- x: torch.Tensor
238
- Batch of EEG windows of shape (batch_size, n_channels, n_times).
239
- """
240
-
241
- if x.ndim == 3:
242
- x = x.unsqueeze(1)
243
-
244
- x1 = self.cnn1(x)
245
- x1 = x1.flatten(start_dim=1)
246
-
247
- x2 = self.cnn2(x)
248
- x2 = x2.flatten(start_dim=1)
249
-
250
- x = torch.cat((x1, x2), dim=1)
251
- x = self.dropout(x)
252
- temp = x.clone()
253
- temp = self.fc(temp)
254
- x = x.unsqueeze(1)
255
- x = self.bilstm(x)
256
- x = x.squeeze()
257
- x = torch.add(x, temp)
258
- x = self.dropout(x)
259
-
260
- feats = self.features_extractor(x)
261
-
262
- if self.return_feats:
263
- return feats
264
- else:
265
- return self.final_layer(feats)