braindecode 0.8.1__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of braindecode might be problematic. Click here for more details.

Files changed (108)
  1. braindecode/__init__.py +1 -2
  2. braindecode/augmentation/__init__.py +39 -19
  3. braindecode/augmentation/base.py +25 -28
  4. braindecode/augmentation/functional.py +237 -100
  5. braindecode/augmentation/transforms.py +325 -158
  6. braindecode/classifier.py +26 -24
  7. braindecode/datasets/__init__.py +28 -10
  8. braindecode/datasets/base.py +220 -134
  9. braindecode/datasets/bbci.py +43 -52
  10. braindecode/datasets/bcicomp.py +47 -32
  11. braindecode/datasets/bids.py +245 -0
  12. braindecode/datasets/mne.py +45 -24
  13. braindecode/datasets/moabb.py +87 -27
  14. braindecode/datasets/nmt.py +311 -0
  15. braindecode/datasets/sleep_physio_challe_18.py +412 -0
  16. braindecode/datasets/sleep_physionet.py +43 -26
  17. braindecode/datasets/tuh.py +324 -140
  18. braindecode/datasets/xy.py +27 -12
  19. braindecode/datautil/__init__.py +37 -18
  20. braindecode/datautil/serialization.py +110 -72
  21. braindecode/eegneuralnet.py +63 -47
  22. braindecode/functional/__init__.py +22 -0
  23. braindecode/functional/functions.py +250 -0
  24. braindecode/functional/initialization.py +47 -0
  25. braindecode/models/__init__.py +84 -14
  26. braindecode/models/atcnet.py +193 -164
  27. braindecode/models/attentionbasenet.py +599 -0
  28. braindecode/models/base.py +86 -102
  29. braindecode/models/biot.py +504 -0
  30. braindecode/models/contrawr.py +317 -0
  31. braindecode/models/ctnet.py +536 -0
  32. braindecode/models/deep4.py +116 -77
  33. braindecode/models/deepsleepnet.py +149 -119
  34. braindecode/models/eegconformer.py +112 -173
  35. braindecode/models/eeginception_erp.py +109 -118
  36. braindecode/models/eeginception_mi.py +161 -97
  37. braindecode/models/eegitnet.py +215 -152
  38. braindecode/models/eegminer.py +254 -0
  39. braindecode/models/eegnet.py +228 -161
  40. braindecode/models/eegnex.py +247 -0
  41. braindecode/models/eegresnet.py +234 -152
  42. braindecode/models/eegsimpleconv.py +199 -0
  43. braindecode/models/eegtcnet.py +335 -0
  44. braindecode/models/fbcnet.py +221 -0
  45. braindecode/models/fblightconvnet.py +313 -0
  46. braindecode/models/fbmsnet.py +324 -0
  47. braindecode/models/hybrid.py +52 -71
  48. braindecode/models/ifnet.py +441 -0
  49. braindecode/models/labram.py +1186 -0
  50. braindecode/models/msvtnet.py +375 -0
  51. braindecode/models/sccnet.py +207 -0
  52. braindecode/models/shallow_fbcsp.py +50 -56
  53. braindecode/models/signal_jepa.py +1011 -0
  54. braindecode/models/sinc_shallow.py +337 -0
  55. braindecode/models/sleep_stager_blanco_2020.py +55 -46
  56. braindecode/models/sleep_stager_chambon_2018.py +54 -53
  57. braindecode/models/sleep_stager_eldele_2021.py +247 -141
  58. braindecode/models/sparcnet.py +424 -0
  59. braindecode/models/summary.csv +41 -0
  60. braindecode/models/syncnet.py +232 -0
  61. braindecode/models/tcn.py +158 -88
  62. braindecode/models/tidnet.py +280 -167
  63. braindecode/models/tsinception.py +283 -0
  64. braindecode/models/usleep.py +190 -177
  65. braindecode/models/util.py +109 -145
  66. braindecode/modules/__init__.py +84 -0
  67. braindecode/modules/activation.py +60 -0
  68. braindecode/modules/attention.py +757 -0
  69. braindecode/modules/blocks.py +108 -0
  70. braindecode/modules/convolution.py +274 -0
  71. braindecode/modules/filter.py +628 -0
  72. braindecode/modules/layers.py +131 -0
  73. braindecode/modules/linear.py +49 -0
  74. braindecode/modules/parametrization.py +38 -0
  75. braindecode/modules/stats.py +77 -0
  76. braindecode/modules/util.py +76 -0
  77. braindecode/modules/wrapper.py +73 -0
  78. braindecode/preprocessing/__init__.py +36 -11
  79. braindecode/preprocessing/mne_preprocess.py +13 -7
  80. braindecode/preprocessing/preprocess.py +139 -75
  81. braindecode/preprocessing/windowers.py +576 -187
  82. braindecode/regressor.py +23 -12
  83. braindecode/samplers/__init__.py +16 -8
  84. braindecode/samplers/base.py +146 -32
  85. braindecode/samplers/ssl.py +162 -17
  86. braindecode/training/__init__.py +18 -10
  87. braindecode/training/callbacks.py +2 -4
  88. braindecode/training/losses.py +3 -8
  89. braindecode/training/scoring.py +76 -68
  90. braindecode/util.py +55 -59
  91. braindecode/version.py +1 -1
  92. braindecode/visualization/__init__.py +2 -3
  93. braindecode/visualization/confusion_matrices.py +117 -73
  94. braindecode/visualization/gradients.py +14 -10
  95. {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/METADATA +42 -58
  96. braindecode-1.1.0.dist-info/RECORD +101 -0
  97. {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/WHEEL +1 -1
  98. {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info/licenses}/LICENSE.txt +1 -1
  99. braindecode-1.1.0.dist-info/licenses/NOTICE.txt +20 -0
  100. braindecode/datautil/mne.py +0 -9
  101. braindecode/datautil/preprocess.py +0 -12
  102. braindecode/datautil/windowers.py +0 -6
  103. braindecode/datautil/xy.py +0 -9
  104. braindecode/models/eeginception.py +0 -317
  105. braindecode/models/functions.py +0 -47
  106. braindecode/models/modules.py +0 -358
  107. braindecode-0.8.1.dist-info/RECORD +0 -68
  108. {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/top_level.txt +0 -0
@@ -1,95 +1,23 @@
1
1
  # Authors: Ghaith Bouallegue <ghaithbouallegue@gmail.com>
2
2
  #
3
3
  # License: BSD-3
4
- import torch
5
- from torch import nn
6
4
  from einops.layers.torch import Rearrange
5
+ from torch import nn
7
6
 
8
- from .modules import Ensure4d
9
- from .base import EEGModuleMixin, deprecated_args
10
-
11
-
12
- class _DepthwiseConv2d(torch.nn.Conv2d):
13
- def __init__(
14
- self,
15
- in_channels,
16
- depth_multiplier=2,
17
- kernel_size=3,
18
- stride=1,
19
- padding=0,
20
- dilation=1,
21
- bias=True,
22
- padding_mode="zeros",
23
- ):
24
- out_channels = in_channels * depth_multiplier
25
- super().__init__(
26
- in_channels=in_channels,
27
- out_channels=out_channels,
28
- kernel_size=kernel_size,
29
- stride=stride,
30
- padding=padding,
31
- dilation=dilation,
32
- groups=in_channels,
33
- bias=bias,
34
- padding_mode=padding_mode,
35
- )
36
-
37
-
38
- class _InceptionBlock(nn.Module):
39
- def __init__(self, branches):
40
- super().__init__()
41
- self.branches = nn.ModuleList(branches)
42
-
43
- def forward(self, x):
44
- return torch.cat([branch(x) for branch in self.branches], 1)
45
-
7
+ from braindecode.models.base import EEGModuleMixin
8
+ from braindecode.modules import DepthwiseConv2d, Ensure4d, InceptionBlock
46
9
 
47
- class _TCBlock(nn.Module):
48
- def __init__(self, in_ch, kernel_length, dialation, padding, drop_prob=0.4):
49
- super().__init__()
50
- self.pad = padding
51
- self.tc1 = nn.Sequential(
52
- _DepthwiseConv2d(
53
- in_ch,
54
- kernel_size=(1, kernel_length),
55
- depth_multiplier=1,
56
- dilation=(1, dialation),
57
- bias=False,
58
- padding="valid",
59
- ),
60
- nn.BatchNorm2d(in_ch),
61
- nn.ELU(),
62
- nn.Dropout(drop_prob),
63
- )
64
-
65
- self.tc2 = nn.Sequential(
66
- _DepthwiseConv2d(
67
- in_ch,
68
- kernel_size=(1, kernel_length),
69
- depth_multiplier=1,
70
- dilation=(1, dialation),
71
- bias=False,
72
- padding="valid",
73
- ),
74
- nn.BatchNorm2d(in_ch),
75
- nn.ELU(),
76
- nn.Dropout(drop_prob),
77
- )
78
10
 
79
- def forward(self, x):
80
- residual = x
81
- paddings = (self.pad, 0, 0, 0, 0, 0, 0, 0)
82
- x = nn.functional.pad(x, paddings)
83
- x = self.tc1(x)
84
- x = nn.functional.pad(x, paddings)
85
- x = self.tc2(x) + residual
86
- return x
11
+ class EEGITNet(EEGModuleMixin, nn.Sequential):
12
+ """EEG-ITNet from Salami, et al (2022) [Salami2022]_
87
13
 
14
+ .. figure:: https://braindecode.org/dev/_static/model/eegitnet.jpg
15
+ :align: center
16
+ :alt: EEG-ITNet Architecture
88
17
 
89
- class EEGITNet(EEGModuleMixin, nn.Sequential):
90
- """EEG-ITNet: An Explainable Inception Temporal
91
- Convolutional Network for motor imagery classification from
92
- Salami et. al 2022.
18
+ EEG-ITNet: An Explainable Inception Temporal
19
+ Convolutional Network for motor imagery classification from
20
+ Salami et al. 2022.
93
21
 
94
22
  See [Salami2022]_ for details.
95
23
 
@@ -99,45 +27,62 @@ class EEGITNet(EEGModuleMixin, nn.Sequential):
99
27
  ----------
100
28
  drop_prob: float
101
29
  Dropout probability.
102
- n_classes: int
103
- Alias for n_outputs.
104
- in_channels: int
105
- Alias for n_chans.
106
- input_window_samples : int
107
- Alias for n_times.
108
-
109
- References
110
- ----------
111
- .. [Salami2022] A. Salami, J. Andreu-Perez and H. Gillmeister, "EEG-ITNet: An Explainable
112
- Inception Temporal Convolutional Network for motor imagery classification," in IEEE Access,
113
- doi: 10.1109/ACCESS.2022.3161489.
30
+ activation: nn.Module, default=nn.ELU
31
+ Activation function class to apply. Should be a PyTorch activation
32
+ module class like ``nn.ReLU`` or ``nn.ELU``. Default is ``nn.ELU``.
33
+ kernel_length : int, optional
34
+ Kernel length for inception branches. Determines the temporal receptive field.
35
+ Default is 16.
36
+ pool_kernel : int, optional
37
+ Pooling kernel size for the average pooling layer. Default is 4.
38
+ tcn_in_channel : int, optional
39
+ Number of input channels for Temporal Convolutional (TC) blocks. Default is 14.
40
+ tcn_kernel_size : int, optional
41
+ Kernel size for the TC blocks. Determines the temporal receptive field.
42
+ Default is 4.
43
+ tcn_padding : int, optional
44
+ Padding size for the TC blocks to maintain the input dimensions. Default is 3.
45
+ drop_prob : float, optional
46
+ Dropout probability applied after certain layers to prevent overfitting.
47
+ Default is 0.4.
48
+ tcn_dilatation : int, optional
49
+ Dilation rate for the first TC block. Subsequent blocks will have
50
+ dilation rates multiplied by powers of 2. Default is 1.
114
51
 
115
52
  Notes
116
53
  -----
117
54
  This implementation is not guaranteed to be correct, has not been checked
118
55
  by original authors, only reimplemented from the paper based on author implementation.
56
+
57
+
58
+ References
59
+ ----------
60
+ .. [Salami2022] A. Salami, J. Andreu-Perez and H. Gillmeister, "EEG-ITNet:
61
+ An Explainable Inception Temporal Convolutional Network for motor
62
+ imagery classification," in IEEE Access,
63
+ doi: 10.1109/ACCESS.2022.3161489.
119
64
  """
120
65
 
121
66
  def __init__(
122
- self,
123
- n_outputs=None,
124
- n_chans=None,
125
- n_times=None,
126
- drop_prob=0.4,
127
- chs_info=None,
128
- input_window_seconds=None,
129
- sfreq=None,
130
- n_classes=None,
131
- in_channels=None,
132
- input_window_samples=None,
133
- add_log_softmax=True,
67
+ self,
68
+ # Braindecode parameters
69
+ n_outputs=None,
70
+ n_chans=None,
71
+ n_times=None,
72
+ chs_info=None,
73
+ input_window_seconds=None,
74
+ sfreq=None,
75
+ # Model parameters
76
+ n_filters_time: int = 2,
77
+ kernel_length: int = 16,
78
+ pool_kernel: int = 4,
79
+ tcn_in_channel: int = 14,
80
+ tcn_kernel_size: int = 4,
81
+ tcn_padding: int = 3,
82
+ drop_prob: float = 0.4,
83
+ tcn_dilatation: int = 1,
84
+ activation: nn.Module = nn.ELU,
134
85
  ):
135
- n_outputs, n_chans, n_times, = deprecated_args(
136
- self,
137
- ('n_classes', 'n_outputs', n_classes, n_outputs),
138
- ('in_channels', 'n_chans', in_channels, n_chans),
139
- ('input_window_samples', 'n_times', input_window_samples, n_times),
140
- )
141
86
  super().__init__(
142
87
  n_outputs=n_outputs,
143
88
  n_chans=n_chans,
@@ -145,88 +90,131 @@ class EEGITNet(EEGModuleMixin, nn.Sequential):
145
90
  n_times=n_times,
146
91
  input_window_seconds=input_window_seconds,
147
92
  sfreq=sfreq,
148
- add_log_softmax=add_log_softmax,
149
93
  )
150
94
  self.mapping = {
151
- 'classification.1.weight': 'final_layer.clf.weight',
152
- 'classification.1.bias': 'final_layer.clf.weight'}
95
+ "classification.1.weight": "final_layer.clf.weight",
96
+ "classification.1.bias": "final_layer.clf.weight",
97
+ }
153
98
 
154
99
  del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq
155
- del n_classes, in_channels, input_window_samples
156
100
 
157
101
  # ======== Handling EEG input ========================
158
102
  self.add_module(
159
- "input_preprocess", nn.Sequential(Ensure4d(),
160
- Rearrange(
161
- "ba ch t 1 -> ba 1 ch t"))
103
+ "input_preprocess",
104
+ nn.Sequential(Ensure4d(), Rearrange("ba ch t 1 -> ba 1 ch t")),
162
105
  )
163
106
  # ======== Inception branches ========================
164
107
  block11 = self._get_inception_branch(
165
- in_channels=self.n_chans, out_channels=2, kernel_length=16
108
+ in_channels=self.n_chans,
109
+ out_channels=n_filters_time,
110
+ kernel_length=kernel_length,
111
+ activation=activation,
166
112
  )
167
113
  block12 = self._get_inception_branch(
168
- in_channels=self.n_chans, out_channels=4, kernel_length=32
114
+ in_channels=self.n_chans,
115
+ out_channels=n_filters_time * 2,
116
+ kernel_length=kernel_length * 2,
117
+ activation=activation,
169
118
  )
170
119
  block13 = self._get_inception_branch(
171
- in_channels=self.n_chans, out_channels=8, kernel_length=64
120
+ in_channels=self.n_chans,
121
+ out_channels=n_filters_time * 4,
122
+ kernel_length=n_filters_time * 4,
123
+ activation=activation,
124
+ )
125
+ self.add_module("inception_block", InceptionBlock((block11, block12, block13)))
126
+ self.pool1 = self.add_module(
127
+ "pooling",
128
+ nn.Sequential(
129
+ nn.AvgPool2d(kernel_size=(1, pool_kernel)), nn.Dropout(drop_prob)
130
+ ),
172
131
  )
173
- self.add_module("inception_block", _InceptionBlock((block11, block12, block13)))
174
- self.pool1 = self.add_module("pooling", nn.Sequential(
175
- nn.AvgPool2d(kernel_size=(1, 4)),
176
- nn.Dropout(drop_prob)))
177
132
  # =========== TC blocks =====================
178
133
  self.add_module(
179
134
  "TC_block1",
180
- _TCBlock(in_ch=14, kernel_length=4, dialation=1, padding=3, drop_prob=drop_prob)
135
+ _TCBlock(
136
+ in_ch=tcn_in_channel,
137
+ kernel_length=tcn_kernel_size,
138
+ dilatation=tcn_dilatation,
139
+ padding=tcn_padding,
140
+ drop_prob=drop_prob,
141
+ activation=activation,
142
+ ),
181
143
  )
182
144
  # ================================
183
145
  self.add_module(
184
146
  "TC_block2",
185
- _TCBlock(in_ch=14, kernel_length=4, dialation=2, padding=6, drop_prob=drop_prob)
147
+ _TCBlock(
148
+ in_ch=tcn_in_channel,
149
+ kernel_length=tcn_kernel_size,
150
+ dilatation=tcn_dilatation * 2,
151
+ padding=tcn_padding * 2,
152
+ drop_prob=drop_prob,
153
+ activation=activation,
154
+ ),
186
155
  )
187
156
  # ================================
188
157
  self.add_module(
189
158
  "TC_block3",
190
- _TCBlock(in_ch=14, kernel_length=4, dialation=4, padding=12, drop_prob=drop_prob)
159
+ _TCBlock(
160
+ in_ch=tcn_in_channel,
161
+ kernel_length=tcn_kernel_size,
162
+ dilatation=tcn_dilatation * 4,
163
+ padding=tcn_padding * 4,
164
+ drop_prob=drop_prob,
165
+ activation=activation,
166
+ ),
191
167
  )
192
168
  # ================================
193
169
  self.add_module(
194
170
  "TC_block4",
195
- _TCBlock(in_ch=14, kernel_length=4, dialation=8, padding=24, drop_prob=drop_prob)
171
+ _TCBlock(
172
+ in_ch=tcn_in_channel,
173
+ kernel_length=tcn_kernel_size,
174
+ dilatation=tcn_dilatation * 8,
175
+ padding=tcn_padding * 8,
176
+ drop_prob=drop_prob,
177
+ activation=activation,
178
+ ),
196
179
  )
197
180
 
198
181
  # ============= Dimensionality reduction ===================
199
- self.add_module("dim_reduction", nn.Sequential(
200
- nn.Conv2d(14, 28, kernel_size=(1, 1)),
201
- nn.BatchNorm2d(28),
202
- nn.ELU(),
203
- nn.AvgPool2d((1, 4)),
204
- nn.Dropout(drop_prob)))
182
+ self.add_module(
183
+ "dim_reduction",
184
+ nn.Sequential(
185
+ nn.Conv2d(tcn_in_channel, tcn_in_channel * 2, kernel_size=(1, 1)),
186
+ nn.BatchNorm2d(tcn_in_channel * 2),
187
+ activation(),
188
+ nn.AvgPool2d((1, tcn_kernel_size)),
189
+ nn.Dropout(drop_prob),
190
+ ),
191
+ )
205
192
  # ============== Classifier ==================
206
193
  # Moved flatten to another layer
207
194
  self.add_module("flatten", nn.Flatten())
208
195
 
209
- # Incorporating classification module and subsequent ones in one final layer
210
- module = nn.Sequential()
211
-
212
- module.add_module("clf",
213
- nn.Linear(int(int(self.n_times / 4) / 4) * 28, self.n_outputs))
196
+ num_features = self.get_output_shape()[-1]
214
197
 
215
- if self.add_log_softmax:
216
- module.add_module("out_fun", nn.LogSoftmax(dim=1))
217
- else:
218
- module.add_module("out_fun", nn.Identity())
219
-
220
- self.add_module("final_layer", module)
198
+ self.add_module("final_layer", nn.Linear(num_features, self.n_outputs))
221
199
 
222
200
  @staticmethod
223
- def _get_inception_branch(in_channels, out_channels, kernel_length, depth_multiplier=1):
201
+ def _get_inception_branch(
202
+ in_channels,
203
+ out_channels,
204
+ kernel_length,
205
+ depth_multiplier=1,
206
+ activation: nn.Module = nn.ELU,
207
+ ):
224
208
  return nn.Sequential(
225
209
  nn.Conv2d(
226
- 1, out_channels, kernel_size=(1, kernel_length), padding="same", bias=False
210
+ 1,
211
+ out_channels,
212
+ kernel_size=(1, kernel_length),
213
+ padding="same",
214
+ bias=False,
227
215
  ),
228
216
  nn.BatchNorm2d(out_channels),
229
- _DepthwiseConv2d(
217
+ DepthwiseConv2d(
230
218
  out_channels,
231
219
  kernel_size=(in_channels, 1),
232
220
  depth_multiplier=depth_multiplier,
@@ -234,4 +222,79 @@ class EEGITNet(EEGModuleMixin, nn.Sequential):
234
222
  padding="valid",
235
223
  ),
236
224
  nn.BatchNorm2d(out_channels),
237
- nn.ELU())
225
+ activation(),
226
+ )
227
+
228
+
229
+ class _TCBlock(nn.Module):
230
+ """
231
+ Temporal Convolutional (TC) block: two dilated depthwise convolution stages wrapped by one residual (skip) connection.
232
+
233
+ Each stage left-pads the last (time) axis with ``padding`` zeros, then applies a ``DepthwiseConv2d`` with a
234
+ ``(1, kernel_length)`` kernel and ``depth_multiplier=1`` (there is no pointwise 1x1 step, so this is depthwise, not
235
+ depthwise-separable), followed by BatchNorm2d, the activation and dropout. The block input is added to the result.
236
+
237
+ Parameters
238
+ ----------
239
+ in_ch : int
240
+ Number of input feature maps; the output has the same number (``depth_multiplier=1``, ``BatchNorm2d(in_ch)``).
241
+ kernel_length : int
242
+ Temporal length of both convolution kernels (kernel size ``(1, kernel_length)``).
243
+ dilatation : int
244
+ Temporal dilation rate of both convolutions.
245
+ padding : int
246
+ Zeros prepended to the time axis before each convolution. Assuming the default stride of 1, it should equal ``dilatation * (kernel_length - 1)`` so each stage keeps the input length and the residual sum is shape-compatible.
247
+ drop_prob : float, optional
248
+ Dropout probability applied at the end of each stage. Default is 0.4.
249
+ activation : nn.Module class, optional
250
+ Activation class (not an instance); it is instantiated once per stage. Should be a PyTorch activation module class
251
+ like ``nn.ReLU`` or ``nn.ELU``. Default is ``nn.ELU``.
252
+ """
253
+
254
+ def __init__(
255
+ self,
256
+ in_ch,
257
+ kernel_length,
258
+ dilatation,
259
+ padding,
260
+ drop_prob=0.4,
261
+ activation: nn.Module = nn.ELU,
262
+ ):
263
+ super().__init__()
264
+ self.pad = padding  # zeros prepended to the time axis before each conv stage (see forward)
265
+ self.tc1 = nn.Sequential(  # stage 1: dilated depthwise conv -> BN -> activation -> dropout
266
+ DepthwiseConv2d(  # depth_multiplier=1 keeps in_ch output channels (matches BatchNorm2d(in_ch))
267
+ in_ch,
268
+ kernel_size=(1, kernel_length),
269
+ depth_multiplier=1,
270
+ dilation=(1, dilatation),
271
+ bias=False,
272
+ padding="valid",  # no implicit padding; forward() applies explicit left-only padding instead
273
+ ),
274
+ nn.BatchNorm2d(in_ch),
275
+ activation(),
276
+ nn.Dropout(drop_prob),
277
+ )
278
+
279
+ self.tc2 = nn.Sequential(  # stage 2: same layout as tc1, separate weights
280
+ DepthwiseConv2d(
281
+ in_ch,
282
+ kernel_size=(1, kernel_length),
283
+ depth_multiplier=1,
284
+ dilation=(1, dilatation),
285
+ bias=False,
286
+ padding="valid",
287
+ ),
288
+ nn.BatchNorm2d(in_ch),
289
+ activation(),
290
+ nn.Dropout(drop_prob),
291
+ )
292
+
293
+ def forward(self, x):
294
+ residual = x  # block input, added back after both stages
295
+ paddings = (self.pad, 0, 0, 0, 0, 0, 0, 0)  # F.pad lists the last dim first: pad only the left of the time axis (causal)
296
+ x = nn.functional.pad(x, paddings)
297
+ x = self.tc1(x)
298
+ x = nn.functional.pad(x, paddings)
299
+ x = self.tc2(x) + residual  # lengths match only if self.pad == dilatation * (kernel_length - 1)
300
+ return x