braindecode 0.8.1__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to one of the supported registries. It is provided for informational purposes only and reflects the packages exactly as they appear in their public registries.


Files changed (108)
  1. braindecode/__init__.py +1 -2
  2. braindecode/augmentation/__init__.py +39 -19
  3. braindecode/augmentation/base.py +25 -28
  4. braindecode/augmentation/functional.py +237 -100
  5. braindecode/augmentation/transforms.py +326 -158
  6. braindecode/classifier.py +26 -24
  7. braindecode/datasets/__init__.py +28 -10
  8. braindecode/datasets/base.py +220 -134
  9. braindecode/datasets/bbci.py +43 -52
  10. braindecode/datasets/bcicomp.py +47 -32
  11. braindecode/datasets/bids.py +245 -0
  12. braindecode/datasets/mne.py +45 -24
  13. braindecode/datasets/moabb.py +87 -27
  14. braindecode/datasets/nmt.py +311 -0
  15. braindecode/datasets/sleep_physio_challe_18.py +412 -0
  16. braindecode/datasets/sleep_physionet.py +43 -26
  17. braindecode/datasets/tuh.py +324 -140
  18. braindecode/datasets/xy.py +27 -12
  19. braindecode/datautil/__init__.py +34 -18
  20. braindecode/datautil/serialization.py +98 -71
  21. braindecode/eegneuralnet.py +63 -47
  22. braindecode/functional/__init__.py +10 -0
  23. braindecode/functional/functions.py +251 -0
  24. braindecode/functional/initialization.py +47 -0
  25. braindecode/models/__init__.py +36 -14
  26. braindecode/models/atcnet.py +153 -159
  27. braindecode/models/attentionbasenet.py +550 -0
  28. braindecode/models/base.py +86 -102
  29. braindecode/models/biot.py +483 -0
  30. braindecode/models/contrawr.py +296 -0
  31. braindecode/models/ctnet.py +450 -0
  32. braindecode/models/deep4.py +64 -75
  33. braindecode/models/deepsleepnet.py +149 -119
  34. braindecode/models/eegconformer.py +111 -171
  35. braindecode/models/eeginception_erp.py +109 -118
  36. braindecode/models/eeginception_mi.py +155 -97
  37. braindecode/models/eegitnet.py +215 -151
  38. braindecode/models/eegminer.py +255 -0
  39. braindecode/models/eegnet.py +229 -161
  40. braindecode/models/eegnex.py +247 -0
  41. braindecode/models/eegresnet.py +234 -152
  42. braindecode/models/eegsimpleconv.py +199 -0
  43. braindecode/models/eegtcnet.py +335 -0
  44. braindecode/models/fbcnet.py +221 -0
  45. braindecode/models/fblightconvnet.py +313 -0
  46. braindecode/models/fbmsnet.py +325 -0
  47. braindecode/models/hybrid.py +52 -71
  48. braindecode/models/ifnet.py +441 -0
  49. braindecode/models/labram.py +1166 -0
  50. braindecode/models/msvtnet.py +375 -0
  51. braindecode/models/sccnet.py +182 -0
  52. braindecode/models/shallow_fbcsp.py +50 -56
  53. braindecode/models/signal_jepa.py +1012 -0
  54. braindecode/models/sinc_shallow.py +337 -0
  55. braindecode/models/sleep_stager_blanco_2020.py +55 -46
  56. braindecode/models/sleep_stager_chambon_2018.py +54 -53
  57. braindecode/models/sleep_stager_eldele_2021.py +248 -141
  58. braindecode/models/sparcnet.py +378 -0
  59. braindecode/models/summary.csv +41 -0
  60. braindecode/models/syncnet.py +232 -0
  61. braindecode/models/tcn.py +158 -88
  62. braindecode/models/tidnet.py +280 -167
  63. braindecode/models/tsinception.py +258 -0
  64. braindecode/models/usleep.py +190 -177
  65. braindecode/models/util.py +109 -141
  66. braindecode/modules/__init__.py +38 -0
  67. braindecode/modules/activation.py +60 -0
  68. braindecode/modules/attention.py +757 -0
  69. braindecode/modules/blocks.py +108 -0
  70. braindecode/modules/convolution.py +274 -0
  71. braindecode/modules/filter.py +632 -0
  72. braindecode/modules/layers.py +133 -0
  73. braindecode/modules/linear.py +50 -0
  74. braindecode/modules/parametrization.py +38 -0
  75. braindecode/modules/stats.py +77 -0
  76. braindecode/modules/util.py +77 -0
  77. braindecode/modules/wrapper.py +75 -0
  78. braindecode/preprocessing/__init__.py +36 -11
  79. braindecode/preprocessing/mne_preprocess.py +13 -7
  80. braindecode/preprocessing/preprocess.py +139 -75
  81. braindecode/preprocessing/windowers.py +576 -187
  82. braindecode/regressor.py +23 -12
  83. braindecode/samplers/__init__.py +16 -8
  84. braindecode/samplers/base.py +148 -32
  85. braindecode/samplers/ssl.py +162 -17
  86. braindecode/training/__init__.py +18 -10
  87. braindecode/training/callbacks.py +2 -4
  88. braindecode/training/losses.py +3 -8
  89. braindecode/training/scoring.py +76 -68
  90. braindecode/util.py +55 -59
  91. braindecode/version.py +1 -1
  92. braindecode/visualization/__init__.py +2 -3
  93. braindecode/visualization/confusion_matrices.py +117 -73
  94. braindecode/visualization/gradients.py +14 -10
  95. {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info}/METADATA +39 -55
  96. braindecode-1.0.0.dist-info/RECORD +101 -0
  97. {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info}/WHEEL +1 -1
  98. {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info/licenses}/LICENSE.txt +1 -1
  99. braindecode-1.0.0.dist-info/licenses/NOTICE.txt +20 -0
  100. braindecode/datautil/mne.py +0 -9
  101. braindecode/datautil/preprocess.py +0 -12
  102. braindecode/datautil/windowers.py +0 -6
  103. braindecode/datautil/xy.py +0 -9
  104. braindecode/models/eeginception.py +0 -317
  105. braindecode/models/functions.py +0 -47
  106. braindecode/models/modules.py +0 -358
  107. braindecode-0.8.1.dist-info/RECORD +0 -68
  108. {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info}/top_level.txt +0 -0
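
Two moves dominate the file list: the private helper files braindecode/models/modules.py and braindecode/models/functions.py are deleted, and new top-level braindecode/modules/ and braindecode/functional/ packages appear in their place. A minimal import-migration sketch, limited to the two names the eegconformer.py diff below actually pulls from the new package:

    # 0.8.1: shared layers were private helpers under braindecode.models.modules
    # 1.0.0: the same building blocks are public, as imported in the diff below
    from braindecode.modules import FeedForwardBlock, MultiHeadAttention
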
@@ -1,23 +1,30 @@
 # Authors: Yonghao Song <eeyhsong@gmail.com>
 #
 # License: BSD (3-clause)
+import warnings
+from typing import Optional
+
 import torch
 import torch.nn.functional as F
 from einops import rearrange
 from einops.layers.torch import Rearrange
-from torch import nn, Tensor
-import warnings
+from torch import Tensor, nn
 
-from .base import EEGModuleMixin, deprecated_args
+from braindecode.models.base import EEGModuleMixin
+from braindecode.modules import FeedForwardBlock, MultiHeadAttention
 
 
 class EEGConformer(EEGModuleMixin, nn.Module):
-    """EEG Conformer.
+    """EEG Conformer from Song et al. (2022) from [song2022]_.
+
+    .. figure:: https://raw.githubusercontent.com/eeyhsong/EEG-Conformer/refs/heads/main/visualization/Fig1.png
+        :align: center
+        :alt: EEGConformer Architecture
 
     Convolutional Transformer for EEG decoding.
 
     The paper and original code with more details about the methodological
-    choices are available at the [Song2022]_ and [ConformerCode]_.
+    choices are available at the [song2022]_ and [ConformerCode]_.
 
     This neural network architecture receives a traditional braindecode input.
     The input shape should be three-dimensional matrix representing the EEG
@@ -68,15 +75,15 @@ class EEGConformer(EEGModuleMixin, nn.Module):
     return_features: bool
         If True, the forward method returns the features before the
         last classification layer. Defaults to False.
-    n_classes :
-        Alias for n_outputs.
-    n_channels :
-        Alias for n_chans.
-    input_window_samples :
-        Alias for n_times.
+    activation: nn.Module
+        Activation function as parameter. Default is nn.ELU
+    activation_transfor: nn.Module
+        Activation function as parameter, applied at the FeedForwardBlock module
+        inside the transformer. Default is nn.GeLU
+
     References
     ----------
-    .. [Song2022] Song, Y., Zheng, Q., Liu, B. and Gao, X., 2022. EEG
+    .. [song2022] Song, Y., Zheng, Q., Liu, B. and Gao, X., 2022. EEG
        conformer: Convolutional transformer for EEG decoding and visualization.
        IEEE Transactions on Neural Systems and Rehabilitation Engineering,
        31, pp.710-719. https://ieeexplore.ieee.org/document/9991178
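
The docstring hunk above documents the new activation and activation_transfor parameters that replace the hard-coded nn.ELU/nn.GELU layers (activation_transfor is the spelling shipped in the release, and "nn.GeLU" in the released docstring should read nn.GELU). A minimal construction sketch with illustrative dimensions; the classes are passed uninstantiated because the model calls activation() itself:

    from torch import nn
    from braindecode.models import EEGConformer

    model = EEGConformer(
        n_outputs=2,
        n_chans=32,
        n_times=500,
        activation=nn.ReLU,           # replaces nn.ELU inside the conv blocks
        activation_transfor=nn.GELU,  # used by FeedForwardBlock in the transformer
    )
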
@@ -86,34 +93,26 @@ class EEGConformer(EEGModuleMixin, nn.Module):
     """
 
     def __init__(
-            self,
-            n_outputs=None,
-            n_chans=None,
-            n_filters_time=40,
-            filter_time_length=25,
-            pool_time_length=75,
-            pool_time_stride=15,
-            drop_prob=0.5,
-            att_depth=6,
-            att_heads=10,
-            att_drop_prob=0.5,
-            final_fc_length=2440,
-            return_features=False,
-            n_times=None,
-            chs_info=None,
-            input_window_seconds=None,
-            sfreq=None,
-            n_classes=None,
-            n_channels=None,
-            input_window_samples=None,
-            add_log_softmax=True,
+        self,
+        n_outputs=None,
+        n_chans=None,
+        n_filters_time=40,
+        filter_time_length=25,
+        pool_time_length=75,
+        pool_time_stride=15,
+        drop_prob=0.5,
+        att_depth=6,
+        att_heads=10,
+        att_drop_prob=0.5,
+        final_fc_length="auto",
+        return_features=False,
+        activation: nn.Module = nn.ELU,
+        activation_transfor: nn.Module = nn.GELU,
+        n_times=None,
+        chs_info=None,
+        input_window_seconds=None,
+        sfreq=None,
     ):
-        n_outputs, n_chans, n_times = deprecated_args(
-            self,
-            ('n_classes', 'n_outputs', n_classes, n_outputs),
-            ('n_channels', 'n_chans', n_channels, n_chans),
-            ('input_window_samples', 'n_times', input_window_samples, n_times)
-        )
         super().__init__(
             n_outputs=n_outputs,
             n_chans=n_chans,
@@ -121,19 +120,22 @@ class EEGConformer(EEGModuleMixin, nn.Module):
             n_times=n_times,
             input_window_seconds=input_window_seconds,
             sfreq=sfreq,
-            add_log_softmax=add_log_softmax,
         )
         self.mapping = {
-            'classification_head.fc.6.weight': 'final_layer.final_layer.0.weight',
-            'classification_head.fc.6.bias': 'final_layer.final_layer.0.bias'
+            "classification_head.fc.6.weight": "final_layer.final_layer.0.weight",
+            "classification_head.fc.6.bias": "final_layer.final_layer.0.bias",
         }
 
         del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq
-        del n_classes, n_channels, input_window_samples
         if not (self.n_chans <= 64):
-            warnings.warn("This model has only been tested on no more " +
-                          "than 64 channels. no guarantee to work with " +
-                          "more channels.", UserWarning)
+            warnings.warn(
+                "This model has only been tested on no more "
+                + "than 64 channels. no guarantee to work with "
+                + "more channels.",
+                UserWarning,
+            )
+
+        self.return_features = return_features
 
         self.patch_embedding = _PatchEmbedding(
             n_filters_time=n_filters_time,
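
The constructor now drops the 0.8.1 compatibility shims outright: n_classes, n_channels, input_window_samples, and add_log_softmax are removed rather than routed through deprecated_args, so passing them raises a TypeError. A hedged migration sketch (dimensions illustrative):

    from braindecode.models import EEGConformer

    # 0.8.1 -- accepted, with a deprecation warning:
    # model = EEGConformer(n_classes=4, n_channels=22, input_window_samples=1000)

    # 1.0.0 -- only the canonical names exist:
    model = EEGConformer(n_outputs=4, n_chans=22, n_times=1000)
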
@@ -141,38 +143,42 @@ class EEGConformer(EEGModuleMixin, nn.Module):
             n_channels=self.n_chans,
             pool_time_length=pool_time_length,
             stride_avg_pool=pool_time_stride,
-            drop_prob=drop_prob)
+            drop_prob=drop_prob,
+            activation=activation,
+        )
 
         if final_fc_length == "auto":
             assert self.n_times is not None
-            final_fc_length = self.get_fc_size()
+            self.final_fc_length = self.get_fc_size()
 
         self.transformer = _TransformerEncoder(
             att_depth=att_depth,
             emb_size=n_filters_time,
             att_heads=att_heads,
-            att_drop=att_drop_prob)
+            att_drop=att_drop_prob,
+            activation=activation_transfor,
+        )
 
         self.fc = _FullyConnected(
-            final_fc_length=final_fc_length)
+            final_fc_length=self.final_fc_length, activation=activation
+        )
 
-        self.final_layer = _FinalLayer(n_classes=self.n_outputs,
-                                       return_features=return_features,
-                                       add_log_softmax=self.add_log_softmax)
+        self.final_layer = nn.Linear(self.fc.hidden_channels, self.n_outputs)
 
     def forward(self, x: Tensor) -> Tensor:
         x = torch.unsqueeze(x, dim=1)  # add one extra dimension
         x = self.patch_embedding(x)
-        x = self.transformer(x)
-        x = self.fc(x)
+        feature = self.transformer(x)
+
+        if self.return_features:
+            return feature
+
+        x = self.fc(feature)
         x = self.final_layer(x)
         return x
 
     def get_fc_size(self):
-
-        out = self.patch_embedding(torch.ones((1, 1,
-                                               self.n_chans,
-                                               self.n_times)))
+        out = self.patch_embedding(torch.ones((1, 1, self.n_chans, self.n_times)))
         size_embedding_1 = out.cpu().data.numpy().shape[1]
         size_embedding_2 = out.cpu().data.numpy().shape[2]
 
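The change from final_fc_length=2440 to "auto" is easiest to see by reproducing get_fc_size() by hand with the default hyperparameters (filter_time_length=25, pool_time_length=75, pool_time_stride=15, n_filters_time=40), using the standard conv/pool output-length formulas:

    n_times = 1000
    after_conv = n_times - 25 + 1            # temporal conv, stride 1 -> 976
    n_patches = (after_conv - 75) // 15 + 1  # average pooling -> 61
    final_fc_length = n_patches * 40         # flattened embedding -> 2440

So the old hard-coded 2440 was only ever correct for 1000-sample windows; "auto" recomputes it for any n_times. Note also that forward() now returns the transformer features directly when return_features=True, where the removed _FinalLayer used to return an (out, features) tuple.
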
@@ -207,26 +213,24 @@ class _PatchEmbedding(nn.Module):
     """
 
     def __init__(
-            self,
-            n_filters_time,
-            filter_time_length,
-            n_channels,
-            pool_time_length,
-            stride_avg_pool,
-            drop_prob,
+        self,
+        n_filters_time,
+        filter_time_length,
+        n_channels,
+        pool_time_length,
+        stride_avg_pool,
+        drop_prob,
+        activation: nn.Module = nn.ELU,
     ):
         super().__init__()
 
         self.shallownet = nn.Sequential(
-            nn.Conv2d(1, n_filters_time,
-                      (1, filter_time_length), (1, 1)),
-            nn.Conv2d(n_filters_time, n_filters_time,
-                      (n_channels, 1), (1, 1)),
+            nn.Conv2d(1, n_filters_time, (1, filter_time_length), (1, 1)),
+            nn.Conv2d(n_filters_time, n_filters_time, (n_channels, 1), (1, 1)),
             nn.BatchNorm2d(num_features=n_filters_time),
-            nn.ELU(),
+            activation(),
             nn.AvgPool2d(
-                kernel_size=(1, pool_time_length),
-                stride=(1, stride_avg_pool)
+                kernel_size=(1, pool_time_length), stride=(1, stride_avg_pool)
             ),
             # pooling acts as slicing to obtain 'patch' along the
             # time dimension as in ViT
@@ -246,79 +250,43 @@ class _PatchEmbedding(nn.Module):
         return x
 
 
-class _MultiHeadAttention(nn.Module):
-    def __init__(self, emb_size, num_heads, dropout):
-        super().__init__()
-        self.emb_size = emb_size
-        self.num_heads = num_heads
-        self.keys = nn.Linear(emb_size, emb_size)
-        self.queries = nn.Linear(emb_size, emb_size)
-        self.values = nn.Linear(emb_size, emb_size)
-        self.att_drop = nn.Dropout(dropout)
-        self.projection = nn.Linear(emb_size, emb_size)
-
-    def forward(self, x: Tensor, mask: Tensor = None) -> Tensor:
-        queries = rearrange(
-            self.queries(x), "b n (h d) -> b h n d", h=self.num_heads
-        )
-        keys = rearrange(
-            self.keys(x), "b n (h d) -> b h n d", h=self.num_heads
-        )
-        values = rearrange(
-            self.values(x), "b n (h d) -> b h n d", h=self.num_heads
-        )
-        energy = torch.einsum("bhqd, bhkd -> bhqk", queries, keys)
-        if mask is not None:
-            fill_value = torch.finfo(torch.float32).min
-            energy.mask_fill(~mask, fill_value)
-
-        scaling = self.emb_size ** (1 / 2)
-        att = F.softmax(energy / scaling, dim=-1)
-        att = self.att_drop(att)
-        out = torch.einsum("bhal, bhlv -> bhav ", att, values)
-        out = rearrange(out, "b h n d -> b n (h d)")
-        out = self.projection(out)
-        return out
-
-
 class _ResidualAdd(nn.Module):
     def __init__(self, fn):
         super().__init__()
         self.fn = fn
 
-    def forward(self, x, **kwargs):
+    def forward(self, x):
         res = x
-        x = self.fn(x, **kwargs)
+        x = self.fn(x)
         x += res
         return x
 
 
-class _FeedForwardBlock(nn.Sequential):
-    def __init__(self, emb_size, expansion, drop_p):
-        super().__init__(
-            nn.Linear(emb_size, expansion * emb_size),
-            nn.GELU(),
-            nn.Dropout(drop_p),
-            nn.Linear(expansion * emb_size, emb_size),
-        )
-
-
 class _TransformerEncoderBlock(nn.Sequential):
-    def __init__(self, emb_size, att_heads, att_drop, forward_expansion=4):
+    def __init__(
+        self,
+        emb_size,
+        att_heads,
+        att_drop,
+        forward_expansion=4,
+        activation: nn.Module = nn.GELU,
+    ):
         super().__init__(
             _ResidualAdd(
                 nn.Sequential(
                     nn.LayerNorm(emb_size),
-                    _MultiHeadAttention(emb_size, att_heads, att_drop),
+                    MultiHeadAttention(emb_size, att_heads, att_drop),
                     nn.Dropout(att_drop),
                 )
             ),
             _ResidualAdd(
                 nn.Sequential(
                     nn.LayerNorm(emb_size),
-                    _FeedForwardBlock(
-                        emb_size, expansion=forward_expansion,
-                        drop_p=att_drop
+                    FeedForwardBlock(
+                        emb_size,
+                        expansion=forward_expansion,
+                        drop_p=att_drop,
+                        activation=activation,
                     ),
                     nn.Dropout(att_drop),
                 )
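
The private _MultiHeadAttention and _FeedForwardBlock copies are deleted in favour of the shared implementations in braindecode.modules; judging from the call sites kept in this hunk, the shared classes take the same positional arguments. Incidentally, the removed attention's masking branch called energy.mask_fill, which is not a Tensor method (masked_fill_ is), so it would have raised on any non-None mask. A usage sketch under the same-signature assumption, shapes illustrative:

    import torch
    from braindecode.modules import MultiHeadAttention

    mha = MultiHeadAttention(40, 10, 0.5)  # emb_size, heads, dropout, as called above
    tokens = torch.randn(8, 61, 40)        # (batch, patches, emb_size)
    out = mha(tokens)                      # same shape as the input
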
@@ -344,19 +312,29 @@ class _TransformerEncoder(nn.Sequential):
 
     """
 
-    def __init__(self, att_depth, emb_size, att_heads, att_drop):
+    def __init__(
+        self, att_depth, emb_size, att_heads, att_drop, activation: nn.Module = nn.GELU
+    ):
         super().__init__(
             *[
-                _TransformerEncoderBlock(emb_size, att_heads, att_drop)
+                _TransformerEncoderBlock(
+                    emb_size, att_heads, att_drop, activation=activation
+                )
                 for _ in range(att_depth)
             ]
         )
 
 
 class _FullyConnected(nn.Module):
-    def __init__(self, final_fc_length,
-                 drop_prob_1=0.5, drop_prob_2=0.3, out_channels=256,
-                 hidden_channels=32):
+    def __init__(
+        self,
+        final_fc_length,
+        drop_prob_1=0.5,
+        drop_prob_2=0.3,
+        out_channels=256,
+        hidden_channels=32,
+        activation: nn.Module = nn.ELU,
+    ):
         """Fully-connected layer for the transformer encoder.
 
         Parameters
@@ -375,17 +353,16 @@ class _FullyConnected(nn.Module):
             Number of output channels for the second linear layer.
         return_features : bool
             Whether to return input features.
-        add_log_softmax: bool
-            Whether to add LogSoftmax non-linearity as the final layer.
         """
 
         super().__init__()
+        self.hidden_channels = hidden_channels
         self.fc = nn.Sequential(
             nn.Linear(final_fc_length, out_channels),
-            nn.ELU(),
+            activation(),
             nn.Dropout(drop_prob_1),
             nn.Linear(out_channels, hidden_channels),
-            nn.ELU(),
+            activation(),
             nn.Dropout(drop_prob_2),
         )
 
@@ -393,40 +370,3 @@ class _FullyConnected(nn.Module):
         x = x.contiguous().view(x.size(0), -1)
         out = self.fc(x)
         return out
-
-
-class _FinalLayer(nn.Module):
-    def __init__(self, n_classes, hidden_channels=32, return_features=False, add_log_softmax=True):
-        """Classification head for the transformer encoder.
-
-        Parameters
-        ----------
-        n_classes : int
-            Number of classes for classification.
-        hidden_channels : int
-            Number of output channels for the second linear layer.
-        return_features : bool
-            Whether to return input features.
-        add_log_softmax : bool
-            Adding LogSoftmax or not.
-        """
-
-        super().__init__()
-        self.final_layer = nn.Sequential(
-            nn.Linear(hidden_channels, n_classes),
-        )
-        self.return_features = return_features
-        if add_log_softmax:
-            classification = nn.LogSoftmax(dim=1)
-        else:
-            classification = nn.Identity()
-        if not self.return_features:
-            self.final_layer.add_module("classification", classification)
-
-    def forward(self, x):
-        if self.return_features:
-            out = self.final_layer(x)
-            return out, x
-        else:
-            out = self.final_layer(x)
-            return out
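
With _FinalLayer deleted, nothing appends a LogSoftmax any more: the 1.0.0 model ends at a plain nn.Linear and emits raw logits. A hedged loss-pairing sketch, reusing a model built as in the earlier examples (tensors illustrative):

    import torch
    from torch import nn

    logits = model(torch.randn(8, 22, 1000))  # (batch, n_outputs), raw logits
    targets = torch.randint(0, 4, (8,))

    # 0.8.1's add_log_softmax=True output paired with nn.NLLLoss();
    # raw logits pair with CrossEntropyLoss, which applies log-softmax internally.
    loss = nn.CrossEntropyLoss()(logits, targets)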