braindecode 0.8.1__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of braindecode might be problematic. Click here for more details.

Files changed (108) hide show
  1. braindecode/__init__.py +1 -2
  2. braindecode/augmentation/__init__.py +39 -19
  3. braindecode/augmentation/base.py +25 -28
  4. braindecode/augmentation/functional.py +237 -100
  5. braindecode/augmentation/transforms.py +325 -158
  6. braindecode/classifier.py +26 -24
  7. braindecode/datasets/__init__.py +28 -10
  8. braindecode/datasets/base.py +220 -134
  9. braindecode/datasets/bbci.py +43 -52
  10. braindecode/datasets/bcicomp.py +47 -32
  11. braindecode/datasets/bids.py +245 -0
  12. braindecode/datasets/mne.py +45 -24
  13. braindecode/datasets/moabb.py +87 -27
  14. braindecode/datasets/nmt.py +311 -0
  15. braindecode/datasets/sleep_physio_challe_18.py +412 -0
  16. braindecode/datasets/sleep_physionet.py +43 -26
  17. braindecode/datasets/tuh.py +324 -140
  18. braindecode/datasets/xy.py +27 -12
  19. braindecode/datautil/__init__.py +37 -18
  20. braindecode/datautil/serialization.py +110 -72
  21. braindecode/eegneuralnet.py +63 -47
  22. braindecode/functional/__init__.py +22 -0
  23. braindecode/functional/functions.py +250 -0
  24. braindecode/functional/initialization.py +47 -0
  25. braindecode/models/__init__.py +84 -14
  26. braindecode/models/atcnet.py +193 -164
  27. braindecode/models/attentionbasenet.py +599 -0
  28. braindecode/models/base.py +86 -102
  29. braindecode/models/biot.py +504 -0
  30. braindecode/models/contrawr.py +317 -0
  31. braindecode/models/ctnet.py +536 -0
  32. braindecode/models/deep4.py +116 -77
  33. braindecode/models/deepsleepnet.py +149 -119
  34. braindecode/models/eegconformer.py +112 -173
  35. braindecode/models/eeginception_erp.py +109 -118
  36. braindecode/models/eeginception_mi.py +161 -97
  37. braindecode/models/eegitnet.py +215 -152
  38. braindecode/models/eegminer.py +254 -0
  39. braindecode/models/eegnet.py +228 -161
  40. braindecode/models/eegnex.py +247 -0
  41. braindecode/models/eegresnet.py +234 -152
  42. braindecode/models/eegsimpleconv.py +199 -0
  43. braindecode/models/eegtcnet.py +335 -0
  44. braindecode/models/fbcnet.py +221 -0
  45. braindecode/models/fblightconvnet.py +313 -0
  46. braindecode/models/fbmsnet.py +324 -0
  47. braindecode/models/hybrid.py +52 -71
  48. braindecode/models/ifnet.py +441 -0
  49. braindecode/models/labram.py +1186 -0
  50. braindecode/models/msvtnet.py +375 -0
  51. braindecode/models/sccnet.py +207 -0
  52. braindecode/models/shallow_fbcsp.py +50 -56
  53. braindecode/models/signal_jepa.py +1011 -0
  54. braindecode/models/sinc_shallow.py +337 -0
  55. braindecode/models/sleep_stager_blanco_2020.py +55 -46
  56. braindecode/models/sleep_stager_chambon_2018.py +54 -53
  57. braindecode/models/sleep_stager_eldele_2021.py +247 -141
  58. braindecode/models/sparcnet.py +424 -0
  59. braindecode/models/summary.csv +41 -0
  60. braindecode/models/syncnet.py +232 -0
  61. braindecode/models/tcn.py +158 -88
  62. braindecode/models/tidnet.py +280 -167
  63. braindecode/models/tsinception.py +283 -0
  64. braindecode/models/usleep.py +190 -177
  65. braindecode/models/util.py +109 -145
  66. braindecode/modules/__init__.py +84 -0
  67. braindecode/modules/activation.py +60 -0
  68. braindecode/modules/attention.py +757 -0
  69. braindecode/modules/blocks.py +108 -0
  70. braindecode/modules/convolution.py +274 -0
  71. braindecode/modules/filter.py +628 -0
  72. braindecode/modules/layers.py +131 -0
  73. braindecode/modules/linear.py +49 -0
  74. braindecode/modules/parametrization.py +38 -0
  75. braindecode/modules/stats.py +77 -0
  76. braindecode/modules/util.py +76 -0
  77. braindecode/modules/wrapper.py +73 -0
  78. braindecode/preprocessing/__init__.py +36 -11
  79. braindecode/preprocessing/mne_preprocess.py +13 -7
  80. braindecode/preprocessing/preprocess.py +139 -75
  81. braindecode/preprocessing/windowers.py +576 -187
  82. braindecode/regressor.py +23 -12
  83. braindecode/samplers/__init__.py +16 -8
  84. braindecode/samplers/base.py +146 -32
  85. braindecode/samplers/ssl.py +162 -17
  86. braindecode/training/__init__.py +18 -10
  87. braindecode/training/callbacks.py +2 -4
  88. braindecode/training/losses.py +3 -8
  89. braindecode/training/scoring.py +76 -68
  90. braindecode/util.py +55 -59
  91. braindecode/version.py +1 -1
  92. braindecode/visualization/__init__.py +2 -3
  93. braindecode/visualization/confusion_matrices.py +117 -73
  94. braindecode/visualization/gradients.py +14 -10
  95. {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/METADATA +42 -58
  96. braindecode-1.1.0.dist-info/RECORD +101 -0
  97. {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/WHEEL +1 -1
  98. {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info/licenses}/LICENSE.txt +1 -1
  99. braindecode-1.1.0.dist-info/licenses/NOTICE.txt +20 -0
  100. braindecode/datautil/mne.py +0 -9
  101. braindecode/datautil/preprocess.py +0 -12
  102. braindecode/datautil/windowers.py +0 -6
  103. braindecode/datautil/xy.py +0 -9
  104. braindecode/models/eeginception.py +0 -317
  105. braindecode/models/functions.py +0 -47
  106. braindecode/models/modules.py +0 -358
  107. braindecode-0.8.1.dist-info/RECORD +0 -68
  108. {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/top_level.txt +0 -0
@@ -3,15 +3,20 @@
3
3
  # License: BSD (3-clause)
4
4
 
5
5
  import torch
6
- from torch import nn
7
6
  from einops.layers.torch import Rearrange
7
+ from torch import nn
8
8
 
9
- from .modules import Ensure4d
10
- from .base import EEGModuleMixin, deprecated_args
9
+ from braindecode.models.base import EEGModuleMixin
10
+ from braindecode.modules import Ensure4d
11
11
 
12
12
 
13
13
  class EEGInceptionMI(EEGModuleMixin, nn.Module):
14
- """EEG Inception for Motor Imagery, as proposed in [1]_
14
+ """EEG Inception for Motor Imagery, as proposed in Zhang et al. (2021) [1]_
15
+
16
+ .. figure:: https://content.cld.iop.org/journals/1741-2552/18/4/046014/revision3/jneabed81f1_hr.jpg
17
+ :align: center
18
+ :alt: EEGInceptionMI Architecture
19
+
15
20
 
16
21
  The model is strongly based on the original InceptionNet for computer
17
22
  vision. The main goal is to extract features in parallel with different
@@ -43,47 +48,31 @@ class EEGInceptionMI(EEGModuleMixin, nn.Module):
43
48
  Size in seconds of the basic 1D convolutional kernel used in inception
44
49
  modules. Each convolutional layer in such modules have kernels of
45
50
  increasing size, odd multiples of this value (e.g. 0.1, 0.3, 0.5, 0.7,
46
- 0.9 here for `n_convs`=5). Defaults to 0.1 s.
51
+ 0.9 here for ``n_convs=5``). Defaults to 0.1 s.
47
52
  activation: nn.Module
48
53
  Activation function. Defaults to ReLU activation.
49
- in_channels : int
50
- Alias for `n_chans`.
51
- n_classes : int
52
- Alias for `n_outputs`.
53
- input_window_s : float, optional
54
- Alias for `input_window_seconds`.
55
54
 
56
55
  References
57
56
  ----------
58
57
  .. [1] Zhang, C., Kim, Y. K., & Eskandarian, A. (2021).
59
- EEG-inception: an accurate and robust end-to-end neural network
60
- for EEG-based motor imagery classification.
61
- Journal of Neural Engineering, 18(4), 046014.
58
+ EEG-inception: an accurate and robust end-to-end neural network
59
+ for EEG-based motor imagery classification.
60
+ Journal of Neural Engineering, 18(4), 046014.
62
61
  """
63
62
 
64
63
  def __init__(
65
- self,
66
- n_chans=None,
67
- n_outputs=None,
68
- input_window_seconds=4.5,
69
- sfreq=250,
70
- n_convs=5,
71
- n_filters=48,
72
- kernel_unit_s=0.1,
73
- activation=nn.ReLU(),
74
- chs_info=None,
75
- n_times=None,
76
- in_channels=None,
77
- n_classes=None,
78
- input_window_s=None,
79
- add_log_softmax=True,
64
+ self,
65
+ n_chans=None,
66
+ n_outputs=None,
67
+ input_window_seconds=None,
68
+ sfreq=250,
69
+ n_convs: int = 5,
70
+ n_filters: int = 48,
71
+ kernel_unit_s: float = 0.1,
72
+ activation: nn.Module = nn.ReLU,
73
+ chs_info=None,
74
+ n_times=None,
80
75
  ):
81
- n_chans, n_outputs, input_window_seconds, = deprecated_args(
82
- self,
83
- ('in_channels', 'n_chans', in_channels, n_chans),
84
- ('n_classes', 'n_outputs', n_classes, n_outputs),
85
- ('input_window_s', 'input_window_seconds', input_window_s, input_window_seconds),
86
- )
87
76
  super().__init__(
88
77
  n_outputs=n_outputs,
89
78
  n_chans=n_chans,
@@ -91,10 +80,8 @@ class EEGInceptionMI(EEGModuleMixin, nn.Module):
91
80
  n_times=n_times,
92
81
  input_window_seconds=input_window_seconds,
93
82
  sfreq=sfreq,
94
- add_log_softmax=add_log_softmax,
95
83
  )
96
84
  del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq
97
- del in_channels, n_classes, input_window_s
98
85
 
99
86
  self.n_convs = n_convs
100
87
  self.n_filters = n_filters
@@ -105,8 +92,9 @@ class EEGInceptionMI(EEGModuleMixin, nn.Module):
105
92
  self.dimshuffle = Rearrange("batch C T 1 -> batch C 1 T")
106
93
 
107
94
  self.mapping = {
108
- 'fc.weight': 'final_layer.fc.weight',
109
- 'tc.bias': 'final_layer.fc.bias'}
95
+ "fc.weight": "final_layer.fc.weight",
96
+ "tc.bias": "final_layer.fc.bias",
97
+ }
110
98
 
111
99
  # ======== Inception branches ========================
112
100
 
@@ -121,16 +109,19 @@ class EEGInceptionMI(EEGModuleMixin, nn.Module):
121
109
 
122
110
  intermediate_in_channels = (self.n_convs + 1) * self.n_filters
123
111
 
124
- self.intermediate_inception_modules_1 = nn.ModuleList([
125
- _InceptionModuleMI(
126
- in_channels=intermediate_in_channels,
127
- n_filters=self.n_filters,
128
- n_convs=self.n_convs,
129
- kernel_unit_s=self.kernel_unit_s,
130
- sfreq=self.sfreq,
131
- activation=self.activation,
132
- ) for _ in range(2)
133
- ])
112
+ self.intermediate_inception_modules_1 = nn.ModuleList(
113
+ [
114
+ _InceptionModuleMI(
115
+ in_channels=intermediate_in_channels,
116
+ n_filters=self.n_filters,
117
+ n_convs=self.n_convs,
118
+ kernel_unit_s=self.kernel_unit_s,
119
+ sfreq=self.sfreq,
120
+ activation=self.activation,
121
+ )
122
+ for _ in range(2)
123
+ ]
124
+ )
134
125
 
135
126
  self.residual_block_1 = _ResidualModuleMI(
136
127
  in_channels=self.n_chans,
@@ -138,16 +129,19 @@ class EEGInceptionMI(EEGModuleMixin, nn.Module):
138
129
  activation=self.activation,
139
130
  )
140
131
 
141
- self.intermediate_inception_modules_2 = nn.ModuleList([
142
- _InceptionModuleMI(
143
- in_channels=intermediate_in_channels,
144
- n_filters=self.n_filters,
145
- n_convs=self.n_convs,
146
- kernel_unit_s=self.kernel_unit_s,
147
- sfreq=self.sfreq,
148
- activation=self.activation,
149
- ) for _ in range(3)
150
- ])
132
+ self.intermediate_inception_modules_2 = nn.ModuleList(
133
+ [
134
+ _InceptionModuleMI(
135
+ in_channels=intermediate_in_channels,
136
+ n_filters=self.n_filters,
137
+ n_convs=self.n_convs,
138
+ kernel_unit_s=self.kernel_unit_s,
139
+ sfreq=self.sfreq,
140
+ activation=self.activation,
141
+ )
142
+ for _ in range(3)
143
+ ]
144
+ )
151
145
 
152
146
  self.residual_block_2 = _ResidualModuleMI(
153
147
  in_channels=intermediate_in_channels,
@@ -172,19 +166,20 @@ class EEGInceptionMI(EEGModuleMixin, nn.Module):
172
166
  self.flat = nn.Flatten()
173
167
 
174
168
  module = nn.Sequential()
175
- module.add_module('fc',
176
- nn.Linear(in_features=intermediate_in_channels,
177
- out_features=self.n_outputs,
178
- bias=True, ))
179
- if self.add_log_softmax:
180
- module.add_module('out_fun', nn.LogSoftmax(dim=1))
181
- else:
182
- module.add_module('out_fun', nn.Identity())
169
+ module.add_module(
170
+ "fc",
171
+ nn.Linear(
172
+ in_features=intermediate_in_channels,
173
+ out_features=self.n_outputs,
174
+ bias=True,
175
+ ),
176
+ )
177
+ module.add_module("out_fun", nn.Identity())
183
178
  self.final_layer = module
184
179
 
185
180
  def forward(
186
- self,
187
- X: torch.Tensor,
181
+ self,
182
+ X: torch.Tensor,
188
183
  ) -> torch.Tensor:
189
184
  X = self.ensuredims(X)
190
185
  X = self.dimshuffle(X)
@@ -211,14 +206,46 @@ class EEGInceptionMI(EEGModuleMixin, nn.Module):
211
206
 
212
207
 
213
208
  class _InceptionModuleMI(nn.Module):
209
+ """
210
+ Inception module.
211
+
212
+ This module implements an inception-like architecture that processes input
213
+ feature maps through multiple convolutional paths with different kernel
214
+ sizes, allowing the network to capture features at multiple scales.
215
+ It includes bottleneck layers, convolutional layers, and pooling layers,
216
+ followed by batch normalization and an activation function.
217
+
218
+ Parameters
219
+ ----------
220
+ in_channels : int
221
+ Number of input channels in the input tensor.
222
+ n_filters : int
223
+ Number of filters (output channels) for each convolutional layer.
224
+ n_convs : int
225
+ Number of convolutional layers in the module (excluding bottleneck
226
+ and pooling paths).
227
+ kernel_unit_s : float, optional
228
+ Base size (in seconds) for the convolutional kernels. The actual kernel
229
+ size is computed as ``(2 * n_units + 1) * kernel_unit``, where ``n_units``
230
+ ranges from 0 to ``n_convs - 1``. Default is 0.1 seconds.
231
+ sfreq : float, optional
232
+ Sampling frequency of the input data, used to convert kernel sizes from
233
+ seconds to samples. Default is 250 Hz.
234
+ activation : nn.Module class, optional
235
+ Activation function class to apply after batch normalization. Should be
236
+ a PyTorch activation module class like ``nn.ReLU`` or ``nn.ELU``.
237
+ Default is ``nn.ReLU``.
238
+
239
+ """
240
+
214
241
  def __init__(
215
- self,
216
- in_channels,
217
- n_filters,
218
- n_convs,
219
- kernel_unit_s=0.1,
220
- sfreq=250,
221
- activation=nn.ReLU(),
242
+ self,
243
+ in_channels,
244
+ n_filters,
245
+ n_convs,
246
+ kernel_unit_s=0.1,
247
+ sfreq=250,
248
+ activation: nn.Module = nn.ReLU,
222
249
  ):
223
250
  super().__init__()
224
251
  self.in_channels = in_channels
@@ -253,23 +280,26 @@ class _InceptionModuleMI(nn.Module):
253
280
  bias=True,
254
281
  )
255
282
 
256
- self.conv_list = nn.ModuleList([
257
- nn.Conv2d(
258
- in_channels=self.n_filters,
259
- out_channels=self.n_filters,
260
- kernel_size=(1, (n_units * 2 + 1) * kernel_unit),
261
- padding="same",
262
- bias=True,
263
- ) for n_units in range(self.n_convs)
264
- ])
283
+ self.conv_list = nn.ModuleList(
284
+ [
285
+ nn.Conv2d(
286
+ in_channels=self.n_filters,
287
+ out_channels=self.n_filters,
288
+ kernel_size=(1, (n_units * 2 + 1) * kernel_unit),
289
+ padding="same",
290
+ bias=True,
291
+ )
292
+ for n_units in range(self.n_convs)
293
+ ]
294
+ )
265
295
 
266
296
  self.bn = nn.BatchNorm2d(self.n_filters * (self.n_convs + 1))
267
297
 
268
- self.activation = activation
298
+ self.activation = activation()
269
299
 
270
300
  def forward(
271
- self,
272
- X: torch.Tensor,
301
+ self,
302
+ X: torch.Tensor,
273
303
  ) -> torch.Tensor:
274
304
  X1 = self.bottleneck(X)
275
305
 
@@ -277,6 +307,12 @@ class _InceptionModuleMI(nn.Module):
277
307
 
278
308
  X2 = self.pooling(X)
279
309
  X2 = self.pooling_conv(X2)
310
+ # Get the target length from one of the conv branches
311
+ target_len = X1[0].shape[-1]
312
+
313
+ # Crop the pooling output if its length does not match
314
+ if X2.shape[-1] != target_len:
315
+ X2 = X2[..., :target_len]
280
316
 
281
317
  out = torch.cat(X1 + [X2], 1)
282
318
 
@@ -285,16 +321,31 @@ class _InceptionModuleMI(nn.Module):
285
321
 
286
322
 
287
323
  class _ResidualModuleMI(nn.Module):
288
- def __init__(
289
- self,
290
- in_channels,
291
- n_filters,
292
- activation=nn.ReLU()
293
- ):
324
+ """
325
+ Residual module.
326
+
327
+ This module performs a 1x1 convolution followed by batch normalization and an activation function.
328
+ It is designed to process input feature maps and produce transformed output feature maps, often used
329
+ in residual connections within neural network architectures.
330
+
331
+ Parameters
332
+ ----------
333
+ in_channels : int
334
+ Number of input channels in the input tensor.
335
+ n_filters : int
336
+ Number of filters (output channels) for the convolutional layer.
337
+ activation : nn.Module class, optional
338
+ Activation function class to apply after batch normalization. Should be a PyTorch
339
+ activation module class (e.g., ``nn.ReLU``, ``nn.ELU``); it is instantiated inside the module. Default is ``nn.ReLU``.
340
+
341
+
342
+ """
343
+
344
+ def __init__(self, in_channels, n_filters, activation: nn.Module = nn.ReLU):
294
345
  super().__init__()
295
346
  self.in_channels = in_channels
296
347
  self.n_filters = n_filters
297
- self.activation = activation
348
+ self.activation = activation()
298
349
 
299
350
  self.bn = nn.BatchNorm2d(self.n_filters)
300
351
  self.conv = nn.Conv2d(
@@ -305,9 +356,22 @@ class _ResidualModuleMI(nn.Module):
305
356
  )
306
357
 
307
358
  def forward(
308
- self,
309
- X: torch.Tensor,
359
+ self,
360
+ X: torch.Tensor,
310
361
  ) -> torch.Tensor:
362
+ """
363
+ Forward pass of the residual module.
364
+
365
+ Parameters
366
+ ----------
367
+ X : torch.Tensor
368
+ Input tensor of shape (batch_size, ch_names, n_times).
369
+
370
+ Returns
371
+ -------
372
+ torch.Tensor
373
+ Output tensor after convolution, batch normalization, and activation function.
374
+ """
311
375
  out = self.conv(X)
312
376
  out = self.bn(out)
313
377
  return self.activation(out)