braindecode 0.8.1__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of braindecode has been flagged as potentially problematic (the original page linked to further details, but that link was lost in extraction).

Files changed (108):
  1. braindecode/__init__.py +1 -2
  2. braindecode/augmentation/__init__.py +39 -19
  3. braindecode/augmentation/base.py +25 -28
  4. braindecode/augmentation/functional.py +237 -100
  5. braindecode/augmentation/transforms.py +325 -158
  6. braindecode/classifier.py +26 -24
  7. braindecode/datasets/__init__.py +28 -10
  8. braindecode/datasets/base.py +220 -134
  9. braindecode/datasets/bbci.py +43 -52
  10. braindecode/datasets/bcicomp.py +47 -32
  11. braindecode/datasets/bids.py +245 -0
  12. braindecode/datasets/mne.py +45 -24
  13. braindecode/datasets/moabb.py +87 -27
  14. braindecode/datasets/nmt.py +311 -0
  15. braindecode/datasets/sleep_physio_challe_18.py +412 -0
  16. braindecode/datasets/sleep_physionet.py +43 -26
  17. braindecode/datasets/tuh.py +324 -140
  18. braindecode/datasets/xy.py +27 -12
  19. braindecode/datautil/__init__.py +37 -18
  20. braindecode/datautil/serialization.py +110 -72
  21. braindecode/eegneuralnet.py +63 -47
  22. braindecode/functional/__init__.py +22 -0
  23. braindecode/functional/functions.py +250 -0
  24. braindecode/functional/initialization.py +47 -0
  25. braindecode/models/__init__.py +84 -14
  26. braindecode/models/atcnet.py +193 -164
  27. braindecode/models/attentionbasenet.py +599 -0
  28. braindecode/models/base.py +86 -102
  29. braindecode/models/biot.py +504 -0
  30. braindecode/models/contrawr.py +317 -0
  31. braindecode/models/ctnet.py +536 -0
  32. braindecode/models/deep4.py +116 -77
  33. braindecode/models/deepsleepnet.py +149 -119
  34. braindecode/models/eegconformer.py +112 -173
  35. braindecode/models/eeginception_erp.py +109 -118
  36. braindecode/models/eeginception_mi.py +161 -97
  37. braindecode/models/eegitnet.py +215 -152
  38. braindecode/models/eegminer.py +254 -0
  39. braindecode/models/eegnet.py +228 -161
  40. braindecode/models/eegnex.py +247 -0
  41. braindecode/models/eegresnet.py +234 -152
  42. braindecode/models/eegsimpleconv.py +199 -0
  43. braindecode/models/eegtcnet.py +335 -0
  44. braindecode/models/fbcnet.py +221 -0
  45. braindecode/models/fblightconvnet.py +313 -0
  46. braindecode/models/fbmsnet.py +324 -0
  47. braindecode/models/hybrid.py +52 -71
  48. braindecode/models/ifnet.py +441 -0
  49. braindecode/models/labram.py +1186 -0
  50. braindecode/models/msvtnet.py +375 -0
  51. braindecode/models/sccnet.py +207 -0
  52. braindecode/models/shallow_fbcsp.py +50 -56
  53. braindecode/models/signal_jepa.py +1011 -0
  54. braindecode/models/sinc_shallow.py +337 -0
  55. braindecode/models/sleep_stager_blanco_2020.py +55 -46
  56. braindecode/models/sleep_stager_chambon_2018.py +54 -53
  57. braindecode/models/sleep_stager_eldele_2021.py +247 -141
  58. braindecode/models/sparcnet.py +424 -0
  59. braindecode/models/summary.csv +41 -0
  60. braindecode/models/syncnet.py +232 -0
  61. braindecode/models/tcn.py +158 -88
  62. braindecode/models/tidnet.py +280 -167
  63. braindecode/models/tsinception.py +283 -0
  64. braindecode/models/usleep.py +190 -177
  65. braindecode/models/util.py +109 -145
  66. braindecode/modules/__init__.py +84 -0
  67. braindecode/modules/activation.py +60 -0
  68. braindecode/modules/attention.py +757 -0
  69. braindecode/modules/blocks.py +108 -0
  70. braindecode/modules/convolution.py +274 -0
  71. braindecode/modules/filter.py +628 -0
  72. braindecode/modules/layers.py +131 -0
  73. braindecode/modules/linear.py +49 -0
  74. braindecode/modules/parametrization.py +38 -0
  75. braindecode/modules/stats.py +77 -0
  76. braindecode/modules/util.py +76 -0
  77. braindecode/modules/wrapper.py +73 -0
  78. braindecode/preprocessing/__init__.py +36 -11
  79. braindecode/preprocessing/mne_preprocess.py +13 -7
  80. braindecode/preprocessing/preprocess.py +139 -75
  81. braindecode/preprocessing/windowers.py +576 -187
  82. braindecode/regressor.py +23 -12
  83. braindecode/samplers/__init__.py +16 -8
  84. braindecode/samplers/base.py +146 -32
  85. braindecode/samplers/ssl.py +162 -17
  86. braindecode/training/__init__.py +18 -10
  87. braindecode/training/callbacks.py +2 -4
  88. braindecode/training/losses.py +3 -8
  89. braindecode/training/scoring.py +76 -68
  90. braindecode/util.py +55 -59
  91. braindecode/version.py +1 -1
  92. braindecode/visualization/__init__.py +2 -3
  93. braindecode/visualization/confusion_matrices.py +117 -73
  94. braindecode/visualization/gradients.py +14 -10
  95. {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/METADATA +42 -58
  96. braindecode-1.1.0.dist-info/RECORD +101 -0
  97. {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/WHEEL +1 -1
  98. {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info/licenses}/LICENSE.txt +1 -1
  99. braindecode-1.1.0.dist-info/licenses/NOTICE.txt +20 -0
  100. braindecode/datautil/mne.py +0 -9
  101. braindecode/datautil/preprocess.py +0 -12
  102. braindecode/datautil/windowers.py +0 -6
  103. braindecode/datautil/xy.py +0 -9
  104. braindecode/models/eeginception.py +0 -317
  105. braindecode/models/functions.py +0 -47
  106. braindecode/models/modules.py +0 -358
  107. braindecode-0.8.1.dist-info/RECORD +0 -68
  108. {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/top_level.txt +0 -0
@@ -1,23 +1,27 @@
1
1
  # Authors: Yonghao Song <eeyhsong@gmail.com>
2
2
  #
3
3
  # License: BSD (3-clause)
4
+ import warnings
5
+
4
6
  import torch
5
- import torch.nn.functional as F
6
- from einops import rearrange
7
7
  from einops.layers.torch import Rearrange
8
- from torch import nn, Tensor
9
- import warnings
8
+ from torch import Tensor, nn
10
9
 
11
- from .base import EEGModuleMixin, deprecated_args
10
+ from braindecode.models.base import EEGModuleMixin
11
+ from braindecode.modules import FeedForwardBlock, MultiHeadAttention
12
12
 
13
13
 
14
14
  class EEGConformer(EEGModuleMixin, nn.Module):
15
- """EEG Conformer.
15
+ """EEG Conformer from Song et al. (2022) from [song2022]_.
16
+
17
+ .. figure:: https://raw.githubusercontent.com/eeyhsong/EEG-Conformer/refs/heads/main/visualization/Fig1.png
18
+ :align: center
19
+ :alt: EEGConformer Architecture
16
20
 
17
21
  Convolutional Transformer for EEG decoding.
18
22
 
19
23
  The paper and original code with more details about the methodological
20
- choices are available at the [Song2022]_ and [ConformerCode]_.
24
+ choices are available at the [song2022]_ and [ConformerCode]_.
21
25
 
22
26
  This neural network architecture receives a traditional braindecode input.
23
27
  The input shape should be three-dimensional matrix representing the EEG
@@ -68,15 +72,15 @@ class EEGConformer(EEGModuleMixin, nn.Module):
68
72
  return_features: bool
69
73
  If True, the forward method returns the features before the
70
74
  last classification layer. Defaults to False.
71
- n_classes :
72
- Alias for n_outputs.
73
- n_channels :
74
- Alias for n_chans.
75
- input_window_samples :
76
- Alias for n_times.
75
+ activation: nn.Module
76
+ Activation function as parameter. Default is nn.ELU
77
+ activation_transfor: nn.Module
78
+ Activation function as parameter, applied at the FeedForwardBlock module
79
+ inside the transformer. Default is nn.GeLU
80
+
77
81
  References
78
82
  ----------
79
- .. [Song2022] Song, Y., Zheng, Q., Liu, B. and Gao, X., 2022. EEG
83
+ .. [song2022] Song, Y., Zheng, Q., Liu, B. and Gao, X., 2022. EEG
80
84
  conformer: Convolutional transformer for EEG decoding and visualization.
81
85
  IEEE Transactions on Neural Systems and Rehabilitation Engineering,
82
86
  31, pp.710-719. https://ieeexplore.ieee.org/document/9991178
@@ -86,34 +90,26 @@ class EEGConformer(EEGModuleMixin, nn.Module):
86
90
  """
87
91
 
88
92
  def __init__(
89
- self,
90
- n_outputs=None,
91
- n_chans=None,
92
- n_filters_time=40,
93
- filter_time_length=25,
94
- pool_time_length=75,
95
- pool_time_stride=15,
96
- drop_prob=0.5,
97
- att_depth=6,
98
- att_heads=10,
99
- att_drop_prob=0.5,
100
- final_fc_length=2440,
101
- return_features=False,
102
- n_times=None,
103
- chs_info=None,
104
- input_window_seconds=None,
105
- sfreq=None,
106
- n_classes=None,
107
- n_channels=None,
108
- input_window_samples=None,
109
- add_log_softmax=True,
93
+ self,
94
+ n_outputs=None,
95
+ n_chans=None,
96
+ n_filters_time=40,
97
+ filter_time_length=25,
98
+ pool_time_length=75,
99
+ pool_time_stride=15,
100
+ drop_prob=0.5,
101
+ att_depth=6,
102
+ att_heads=10,
103
+ att_drop_prob=0.5,
104
+ final_fc_length="auto",
105
+ return_features=False,
106
+ activation: nn.Module = nn.ELU,
107
+ activation_transfor: nn.Module = nn.GELU,
108
+ n_times=None,
109
+ chs_info=None,
110
+ input_window_seconds=None,
111
+ sfreq=None,
110
112
  ):
111
- n_outputs, n_chans, n_times = deprecated_args(
112
- self,
113
- ('n_classes', 'n_outputs', n_classes, n_outputs),
114
- ('n_channels', 'n_chans', n_channels, n_chans),
115
- ('input_window_samples', 'n_times', input_window_samples, n_times)
116
- )
117
113
  super().__init__(
118
114
  n_outputs=n_outputs,
119
115
  n_chans=n_chans,
@@ -121,19 +117,22 @@ class EEGConformer(EEGModuleMixin, nn.Module):
121
117
  n_times=n_times,
122
118
  input_window_seconds=input_window_seconds,
123
119
  sfreq=sfreq,
124
- add_log_softmax=add_log_softmax,
125
120
  )
126
121
  self.mapping = {
127
- 'classification_head.fc.6.weight': 'final_layer.final_layer.0.weight',
128
- 'classification_head.fc.6.bias': 'final_layer.final_layer.0.bias'
122
+ "classification_head.fc.6.weight": "final_layer.final_layer.0.weight",
123
+ "classification_head.fc.6.bias": "final_layer.final_layer.0.bias",
129
124
  }
130
125
 
131
126
  del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq
132
- del n_classes, n_channels, input_window_samples
133
127
  if not (self.n_chans <= 64):
134
- warnings.warn("This model has only been tested on no more " +
135
- "than 64 channels. no guarantee to work with " +
136
- "more channels.", UserWarning)
128
+ warnings.warn(
129
+ "This model has only been tested on no more "
130
+ + "than 64 channels. no guarantee to work with "
131
+ + "more channels.",
132
+ UserWarning,
133
+ )
134
+
135
+ self.return_features = return_features
137
136
 
138
137
  self.patch_embedding = _PatchEmbedding(
139
138
  n_filters_time=n_filters_time,
@@ -141,38 +140,44 @@ class EEGConformer(EEGModuleMixin, nn.Module):
141
140
  n_channels=self.n_chans,
142
141
  pool_time_length=pool_time_length,
143
142
  stride_avg_pool=pool_time_stride,
144
- drop_prob=drop_prob)
143
+ drop_prob=drop_prob,
144
+ activation=activation,
145
+ )
145
146
 
146
147
  if final_fc_length == "auto":
147
148
  assert self.n_times is not None
148
- final_fc_length = self.get_fc_size()
149
+ self.final_fc_length = self.get_fc_size()
150
+ else:
151
+ self.final_fc_length = final_fc_length
149
152
 
150
153
  self.transformer = _TransformerEncoder(
151
154
  att_depth=att_depth,
152
155
  emb_size=n_filters_time,
153
156
  att_heads=att_heads,
154
- att_drop=att_drop_prob)
157
+ att_drop=att_drop_prob,
158
+ activation=activation_transfor,
159
+ )
155
160
 
156
161
  self.fc = _FullyConnected(
157
- final_fc_length=final_fc_length)
162
+ final_fc_length=self.final_fc_length, activation=activation
163
+ )
158
164
 
159
- self.final_layer = _FinalLayer(n_classes=self.n_outputs,
160
- return_features=return_features,
161
- add_log_softmax=self.add_log_softmax)
165
+ self.final_layer = nn.Linear(self.fc.hidden_channels, self.n_outputs)
162
166
 
163
167
  def forward(self, x: Tensor) -> Tensor:
164
168
  x = torch.unsqueeze(x, dim=1) # add one extra dimension
165
169
  x = self.patch_embedding(x)
166
- x = self.transformer(x)
167
- x = self.fc(x)
170
+ feature = self.transformer(x)
171
+
172
+ if self.return_features:
173
+ return feature
174
+
175
+ x = self.fc(feature)
168
176
  x = self.final_layer(x)
169
177
  return x
170
178
 
171
179
  def get_fc_size(self):
172
-
173
- out = self.patch_embedding(torch.ones((1, 1,
174
- self.n_chans,
175
- self.n_times)))
180
+ out = self.patch_embedding(torch.ones((1, 1, self.n_chans, self.n_times)))
176
181
  size_embedding_1 = out.cpu().data.numpy().shape[1]
177
182
  size_embedding_2 = out.cpu().data.numpy().shape[2]
178
183
 
@@ -207,26 +212,24 @@ class _PatchEmbedding(nn.Module):
207
212
  """
208
213
 
209
214
  def __init__(
210
- self,
211
- n_filters_time,
212
- filter_time_length,
213
- n_channels,
214
- pool_time_length,
215
- stride_avg_pool,
216
- drop_prob,
215
+ self,
216
+ n_filters_time,
217
+ filter_time_length,
218
+ n_channels,
219
+ pool_time_length,
220
+ stride_avg_pool,
221
+ drop_prob,
222
+ activation: nn.Module = nn.ELU,
217
223
  ):
218
224
  super().__init__()
219
225
 
220
226
  self.shallownet = nn.Sequential(
221
- nn.Conv2d(1, n_filters_time,
222
- (1, filter_time_length), (1, 1)),
223
- nn.Conv2d(n_filters_time, n_filters_time,
224
- (n_channels, 1), (1, 1)),
227
+ nn.Conv2d(1, n_filters_time, (1, filter_time_length), (1, 1)),
228
+ nn.Conv2d(n_filters_time, n_filters_time, (n_channels, 1), (1, 1)),
225
229
  nn.BatchNorm2d(num_features=n_filters_time),
226
- nn.ELU(),
230
+ activation(),
227
231
  nn.AvgPool2d(
228
- kernel_size=(1, pool_time_length),
229
- stride=(1, stride_avg_pool)
232
+ kernel_size=(1, pool_time_length), stride=(1, stride_avg_pool)
230
233
  ),
231
234
  # pooling acts as slicing to obtain 'patch' along the
232
235
  # time dimension as in ViT
@@ -246,79 +249,43 @@ class _PatchEmbedding(nn.Module):
246
249
  return x
247
250
 
248
251
 
249
- class _MultiHeadAttention(nn.Module):
250
- def __init__(self, emb_size, num_heads, dropout):
251
- super().__init__()
252
- self.emb_size = emb_size
253
- self.num_heads = num_heads
254
- self.keys = nn.Linear(emb_size, emb_size)
255
- self.queries = nn.Linear(emb_size, emb_size)
256
- self.values = nn.Linear(emb_size, emb_size)
257
- self.att_drop = nn.Dropout(dropout)
258
- self.projection = nn.Linear(emb_size, emb_size)
259
-
260
- def forward(self, x: Tensor, mask: Tensor = None) -> Tensor:
261
- queries = rearrange(
262
- self.queries(x), "b n (h d) -> b h n d", h=self.num_heads
263
- )
264
- keys = rearrange(
265
- self.keys(x), "b n (h d) -> b h n d", h=self.num_heads
266
- )
267
- values = rearrange(
268
- self.values(x), "b n (h d) -> b h n d", h=self.num_heads
269
- )
270
- energy = torch.einsum("bhqd, bhkd -> bhqk", queries, keys)
271
- if mask is not None:
272
- fill_value = torch.finfo(torch.float32).min
273
- energy.mask_fill(~mask, fill_value)
274
-
275
- scaling = self.emb_size ** (1 / 2)
276
- att = F.softmax(energy / scaling, dim=-1)
277
- att = self.att_drop(att)
278
- out = torch.einsum("bhal, bhlv -> bhav ", att, values)
279
- out = rearrange(out, "b h n d -> b n (h d)")
280
- out = self.projection(out)
281
- return out
282
-
283
-
284
252
  class _ResidualAdd(nn.Module):
285
253
  def __init__(self, fn):
286
254
  super().__init__()
287
255
  self.fn = fn
288
256
 
289
- def forward(self, x, **kwargs):
257
+ def forward(self, x):
290
258
  res = x
291
- x = self.fn(x, **kwargs)
259
+ x = self.fn(x)
292
260
  x += res
293
261
  return x
294
262
 
295
263
 
296
- class _FeedForwardBlock(nn.Sequential):
297
- def __init__(self, emb_size, expansion, drop_p):
298
- super().__init__(
299
- nn.Linear(emb_size, expansion * emb_size),
300
- nn.GELU(),
301
- nn.Dropout(drop_p),
302
- nn.Linear(expansion * emb_size, emb_size),
303
- )
304
-
305
-
306
264
  class _TransformerEncoderBlock(nn.Sequential):
307
- def __init__(self, emb_size, att_heads, att_drop, forward_expansion=4):
265
+ def __init__(
266
+ self,
267
+ emb_size,
268
+ att_heads,
269
+ att_drop,
270
+ forward_expansion=4,
271
+ activation: nn.Module = nn.GELU,
272
+ ):
308
273
  super().__init__(
309
274
  _ResidualAdd(
310
275
  nn.Sequential(
311
276
  nn.LayerNorm(emb_size),
312
- _MultiHeadAttention(emb_size, att_heads, att_drop),
277
+ MultiHeadAttention(emb_size, att_heads, att_drop),
313
278
  nn.Dropout(att_drop),
314
279
  )
315
280
  ),
316
281
  _ResidualAdd(
317
282
  nn.Sequential(
318
283
  nn.LayerNorm(emb_size),
319
- _FeedForwardBlock(
320
- emb_size, expansion=forward_expansion,
321
- drop_p=att_drop
284
+ FeedForwardBlock(
285
+ emb_size,
286
+ expansion=forward_expansion,
287
+ drop_p=att_drop,
288
+ activation=activation,
322
289
  ),
323
290
  nn.Dropout(att_drop),
324
291
  )
@@ -344,19 +311,29 @@ class _TransformerEncoder(nn.Sequential):
344
311
 
345
312
  """
346
313
 
347
- def __init__(self, att_depth, emb_size, att_heads, att_drop):
314
+ def __init__(
315
+ self, att_depth, emb_size, att_heads, att_drop, activation: nn.Module = nn.GELU
316
+ ):
348
317
  super().__init__(
349
318
  *[
350
- _TransformerEncoderBlock(emb_size, att_heads, att_drop)
319
+ _TransformerEncoderBlock(
320
+ emb_size, att_heads, att_drop, activation=activation
321
+ )
351
322
  for _ in range(att_depth)
352
323
  ]
353
324
  )
354
325
 
355
326
 
356
327
  class _FullyConnected(nn.Module):
357
- def __init__(self, final_fc_length,
358
- drop_prob_1=0.5, drop_prob_2=0.3, out_channels=256,
359
- hidden_channels=32):
328
+ def __init__(
329
+ self,
330
+ final_fc_length,
331
+ drop_prob_1=0.5,
332
+ drop_prob_2=0.3,
333
+ out_channels=256,
334
+ hidden_channels=32,
335
+ activation: nn.Module = nn.ELU,
336
+ ):
360
337
  """Fully-connected layer for the transformer encoder.
361
338
 
362
339
  Parameters
@@ -375,17 +352,16 @@ class _FullyConnected(nn.Module):
375
352
  Number of output channels for the second linear layer.
376
353
  return_features : bool
377
354
  Whether to return input features.
378
- add_log_softmax: bool
379
- Whether to add LogSoftmax non-linearity as the final layer.
380
355
  """
381
356
 
382
357
  super().__init__()
358
+ self.hidden_channels = hidden_channels
383
359
  self.fc = nn.Sequential(
384
360
  nn.Linear(final_fc_length, out_channels),
385
- nn.ELU(),
361
+ activation(),
386
362
  nn.Dropout(drop_prob_1),
387
363
  nn.Linear(out_channels, hidden_channels),
388
- nn.ELU(),
364
+ activation(),
389
365
  nn.Dropout(drop_prob_2),
390
366
  )
391
367
 
@@ -393,40 +369,3 @@ class _FullyConnected(nn.Module):
393
369
  x = x.contiguous().view(x.size(0), -1)
394
370
  out = self.fc(x)
395
371
  return out
396
-
397
-
398
- class _FinalLayer(nn.Module):
399
- def __init__(self, n_classes, hidden_channels=32, return_features=False, add_log_softmax=True):
400
- """Classification head for the transformer encoder.
401
-
402
- Parameters
403
- ----------
404
- n_classes : int
405
- Number of classes for classification.
406
- hidden_channels : int
407
- Number of output channels for the second linear layer.
408
- return_features : bool
409
- Whether to return input features.
410
- add_log_softmax : bool
411
- Adding LogSoftmax or not.
412
- """
413
-
414
- super().__init__()
415
- self.final_layer = nn.Sequential(
416
- nn.Linear(hidden_channels, n_classes),
417
- )
418
- self.return_features = return_features
419
- if add_log_softmax:
420
- classification = nn.LogSoftmax(dim=1)
421
- else:
422
- classification = nn.Identity()
423
- if not self.return_features:
424
- self.final_layer.add_module("classification", classification)
425
-
426
- def forward(self, x):
427
- if self.return_features:
428
- out = self.final_layer(x)
429
- return out, x
430
- else:
431
- out = self.final_layer(x)
432
- return out