braindecode 0.8.1__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of braindecode might be problematic; the registry page links to more details.

Files changed (108)
  1. braindecode/__init__.py +1 -2
  2. braindecode/augmentation/__init__.py +39 -19
  3. braindecode/augmentation/base.py +25 -28
  4. braindecode/augmentation/functional.py +237 -100
  5. braindecode/augmentation/transforms.py +326 -158
  6. braindecode/classifier.py +26 -24
  7. braindecode/datasets/__init__.py +28 -10
  8. braindecode/datasets/base.py +220 -134
  9. braindecode/datasets/bbci.py +43 -52
  10. braindecode/datasets/bcicomp.py +47 -32
  11. braindecode/datasets/bids.py +245 -0
  12. braindecode/datasets/mne.py +45 -24
  13. braindecode/datasets/moabb.py +87 -27
  14. braindecode/datasets/nmt.py +311 -0
  15. braindecode/datasets/sleep_physio_challe_18.py +412 -0
  16. braindecode/datasets/sleep_physionet.py +43 -26
  17. braindecode/datasets/tuh.py +324 -140
  18. braindecode/datasets/xy.py +27 -12
  19. braindecode/datautil/__init__.py +34 -18
  20. braindecode/datautil/serialization.py +98 -71
  21. braindecode/eegneuralnet.py +63 -47
  22. braindecode/functional/__init__.py +10 -0
  23. braindecode/functional/functions.py +251 -0
  24. braindecode/functional/initialization.py +47 -0
  25. braindecode/models/__init__.py +36 -14
  26. braindecode/models/atcnet.py +153 -159
  27. braindecode/models/attentionbasenet.py +550 -0
  28. braindecode/models/base.py +86 -102
  29. braindecode/models/biot.py +483 -0
  30. braindecode/models/contrawr.py +296 -0
  31. braindecode/models/ctnet.py +450 -0
  32. braindecode/models/deep4.py +64 -75
  33. braindecode/models/deepsleepnet.py +149 -119
  34. braindecode/models/eegconformer.py +111 -171
  35. braindecode/models/eeginception_erp.py +109 -118
  36. braindecode/models/eeginception_mi.py +155 -97
  37. braindecode/models/eegitnet.py +215 -151
  38. braindecode/models/eegminer.py +255 -0
  39. braindecode/models/eegnet.py +229 -161
  40. braindecode/models/eegnex.py +247 -0
  41. braindecode/models/eegresnet.py +234 -152
  42. braindecode/models/eegsimpleconv.py +199 -0
  43. braindecode/models/eegtcnet.py +335 -0
  44. braindecode/models/fbcnet.py +221 -0
  45. braindecode/models/fblightconvnet.py +313 -0
  46. braindecode/models/fbmsnet.py +325 -0
  47. braindecode/models/hybrid.py +52 -71
  48. braindecode/models/ifnet.py +441 -0
  49. braindecode/models/labram.py +1166 -0
  50. braindecode/models/msvtnet.py +375 -0
  51. braindecode/models/sccnet.py +182 -0
  52. braindecode/models/shallow_fbcsp.py +50 -56
  53. braindecode/models/signal_jepa.py +1012 -0
  54. braindecode/models/sinc_shallow.py +337 -0
  55. braindecode/models/sleep_stager_blanco_2020.py +55 -46
  56. braindecode/models/sleep_stager_chambon_2018.py +54 -53
  57. braindecode/models/sleep_stager_eldele_2021.py +248 -141
  58. braindecode/models/sparcnet.py +378 -0
  59. braindecode/models/summary.csv +41 -0
  60. braindecode/models/syncnet.py +232 -0
  61. braindecode/models/tcn.py +158 -88
  62. braindecode/models/tidnet.py +280 -167
  63. braindecode/models/tsinception.py +258 -0
  64. braindecode/models/usleep.py +190 -177
  65. braindecode/models/util.py +109 -141
  66. braindecode/modules/__init__.py +38 -0
  67. braindecode/modules/activation.py +60 -0
  68. braindecode/modules/attention.py +757 -0
  69. braindecode/modules/blocks.py +108 -0
  70. braindecode/modules/convolution.py +274 -0
  71. braindecode/modules/filter.py +632 -0
  72. braindecode/modules/layers.py +133 -0
  73. braindecode/modules/linear.py +50 -0
  74. braindecode/modules/parametrization.py +38 -0
  75. braindecode/modules/stats.py +77 -0
  76. braindecode/modules/util.py +77 -0
  77. braindecode/modules/wrapper.py +75 -0
  78. braindecode/preprocessing/__init__.py +36 -11
  79. braindecode/preprocessing/mne_preprocess.py +13 -7
  80. braindecode/preprocessing/preprocess.py +139 -75
  81. braindecode/preprocessing/windowers.py +576 -187
  82. braindecode/regressor.py +23 -12
  83. braindecode/samplers/__init__.py +16 -8
  84. braindecode/samplers/base.py +148 -32
  85. braindecode/samplers/ssl.py +162 -17
  86. braindecode/training/__init__.py +18 -10
  87. braindecode/training/callbacks.py +2 -4
  88. braindecode/training/losses.py +3 -8
  89. braindecode/training/scoring.py +76 -68
  90. braindecode/util.py +55 -59
  91. braindecode/version.py +1 -1
  92. braindecode/visualization/__init__.py +2 -3
  93. braindecode/visualization/confusion_matrices.py +117 -73
  94. braindecode/visualization/gradients.py +14 -10
  95. {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info}/METADATA +39 -55
  96. braindecode-1.0.0.dist-info/RECORD +101 -0
  97. {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info}/WHEEL +1 -1
  98. {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info/licenses}/LICENSE.txt +1 -1
  99. braindecode-1.0.0.dist-info/licenses/NOTICE.txt +20 -0
  100. braindecode/datautil/mne.py +0 -9
  101. braindecode/datautil/preprocess.py +0 -12
  102. braindecode/datautil/windowers.py +0 -6
  103. braindecode/datautil/xy.py +0 -9
  104. braindecode/models/eeginception.py +0 -317
  105. braindecode/models/functions.py +0 -47
  106. braindecode/models/modules.py +0 -358
  107. braindecode-0.8.1.dist-info/RECORD +0 -68
  108. {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info}/top_level.txt +0 -0
@@ -2,94 +2,23 @@
2
2
  #
3
3
  # License: BSD-3
4
4
  import torch
5
- from torch import nn
6
5
  from einops.layers.torch import Rearrange
6
+ from torch import nn
7
7
 
8
- from .modules import Ensure4d
9
- from .base import EEGModuleMixin, deprecated_args
10
-
11
-
12
- class _DepthwiseConv2d(torch.nn.Conv2d):
13
- def __init__(
14
- self,
15
- in_channels,
16
- depth_multiplier=2,
17
- kernel_size=3,
18
- stride=1,
19
- padding=0,
20
- dilation=1,
21
- bias=True,
22
- padding_mode="zeros",
23
- ):
24
- out_channels = in_channels * depth_multiplier
25
- super().__init__(
26
- in_channels=in_channels,
27
- out_channels=out_channels,
28
- kernel_size=kernel_size,
29
- stride=stride,
30
- padding=padding,
31
- dilation=dilation,
32
- groups=in_channels,
33
- bias=bias,
34
- padding_mode=padding_mode,
35
- )
36
-
37
-
38
- class _InceptionBlock(nn.Module):
39
- def __init__(self, branches):
40
- super().__init__()
41
- self.branches = nn.ModuleList(branches)
42
-
43
- def forward(self, x):
44
- return torch.cat([branch(x) for branch in self.branches], 1)
45
-
8
+ from braindecode.models.base import EEGModuleMixin
9
+ from braindecode.modules import DepthwiseConv2d, Ensure4d, InceptionBlock
46
10
 
47
- class _TCBlock(nn.Module):
48
- def __init__(self, in_ch, kernel_length, dialation, padding, drop_prob=0.4):
49
- super().__init__()
50
- self.pad = padding
51
- self.tc1 = nn.Sequential(
52
- _DepthwiseConv2d(
53
- in_ch,
54
- kernel_size=(1, kernel_length),
55
- depth_multiplier=1,
56
- dilation=(1, dialation),
57
- bias=False,
58
- padding="valid",
59
- ),
60
- nn.BatchNorm2d(in_ch),
61
- nn.ELU(),
62
- nn.Dropout(drop_prob),
63
- )
64
-
65
- self.tc2 = nn.Sequential(
66
- _DepthwiseConv2d(
67
- in_ch,
68
- kernel_size=(1, kernel_length),
69
- depth_multiplier=1,
70
- dilation=(1, dialation),
71
- bias=False,
72
- padding="valid",
73
- ),
74
- nn.BatchNorm2d(in_ch),
75
- nn.ELU(),
76
- nn.Dropout(drop_prob),
77
- )
78
11
 
79
- def forward(self, x):
80
- residual = x
81
- paddings = (self.pad, 0, 0, 0, 0, 0, 0, 0)
82
- x = nn.functional.pad(x, paddings)
83
- x = self.tc1(x)
84
- x = nn.functional.pad(x, paddings)
85
- x = self.tc2(x) + residual
86
- return x
12
+ class EEGITNet(EEGModuleMixin, nn.Sequential):
13
+ """EEG-ITNet from Salami, et al (2022) [Salami2022]_
87
14
 
15
+ .. figure:: https://braindecode.org/dev/_static/model/eegitnet.jpg
16
+ :align: center
17
+ :alt: EEG-ITNet Architecture
88
18
 
89
- class EEGITNet(EEGModuleMixin, nn.Sequential):
90
- """EEG-ITNet: An Explainable Inception Temporal
91
- Convolutional Network for motor imagery classification from
92
- Salami et. al 2022.
19
+ EEG-ITNet: An Explainable Inception Temporal
20
+ Convolutional Network for motor imagery classification from
21
+ Salami et al. 2022.
93
22
 
94
23
  See [Salami2022]_ for details.
95
24
 
@@ -99,45 +28,62 @@ class EEGITNet(EEGModuleMixin, nn.Sequential):
99
28
  ----------
100
29
  drop_prob: float
101
30
  Dropout probability.
102
- n_classes: int
103
- Alias for n_outputs.
104
- in_channels: int
105
- Alias for n_chans.
106
- input_window_samples : int
107
- Alias for n_times.
108
-
109
- References
110
- ----------
111
- .. [Salami2022] A. Salami, J. Andreu-Perez and H. Gillmeister, "EEG-ITNet: An Explainable
112
- Inception Temporal Convolutional Network for motor imagery classification," in IEEE Access,
113
- doi: 10.1109/ACCESS.2022.3161489.
31
+ activation: nn.Module, default=nn.ELU
32
+ Activation function class to apply. Should be a PyTorch activation
33
+ module class like ``nn.ReLU`` or ``nn.ELU``. Default is ``nn.ELU``.
34
+ kernel_length : int, optional
35
+ Kernel length for inception branches. Determines the temporal receptive field.
36
+ Default is 16.
37
+ pool_kernel : int, optional
38
+ Pooling kernel size for the average pooling layer. Default is 4.
39
+ tcn_in_channel : int, optional
40
+ Number of input channels for Temporal Convolutional (TC) blocks. Default is 14.
41
+ tcn_kernel_size : int, optional
42
+ Kernel size for the TC blocks. Determines the temporal receptive field.
43
+ Default is 4.
44
+ tcn_padding : int, optional
45
+ Padding size for the TC blocks to maintain the input dimensions. Default is 3.
46
+ drop_prob : float, optional
47
+ Dropout probability applied after certain layers to prevent overfitting.
48
+ Default is 0.4.
49
+ tcn_dilatation : int, optional
50
+ Dilation rate for the first TC block. Subsequent blocks will have
51
+ dilation rates multiplied by powers of 2. Default is 1.
114
52
 
115
53
  Notes
116
54
  -----
117
55
  This implementation is not guaranteed to be correct, has not been checked
118
56
  by original authors, only reimplemented from the paper based on author implementation.
57
+
58
+
59
+ References
60
+ ----------
61
+ .. [Salami2022] A. Salami, J. Andreu-Perez and H. Gillmeister, "EEG-ITNet:
62
+ An Explainable Inception Temporal Convolutional Network for motor
63
+ imagery classification," in IEEE Access,
64
+ doi: 10.1109/ACCESS.2022.3161489.
119
65
  """
120
66
 
121
67
  def __init__(
122
- self,
123
- n_outputs=None,
124
- n_chans=None,
125
- n_times=None,
126
- drop_prob=0.4,
127
- chs_info=None,
128
- input_window_seconds=None,
129
- sfreq=None,
130
- n_classes=None,
131
- in_channels=None,
132
- input_window_samples=None,
133
- add_log_softmax=True,
68
+ self,
69
+ # Braindecode parameters
70
+ n_outputs=None,
71
+ n_chans=None,
72
+ n_times=None,
73
+ chs_info=None,
74
+ input_window_seconds=None,
75
+ sfreq=None,
76
+ # Model parameters
77
+ n_filters_time: int = 2,
78
+ kernel_length: int = 16,
79
+ pool_kernel: int = 4,
80
+ tcn_in_channel: int = 14,
81
+ tcn_kernel_size: int = 4,
82
+ tcn_padding: int = 3,
83
+ drop_prob: float = 0.4,
84
+ tcn_dilatation: int = 1,
85
+ activation: nn.Module = nn.ELU,
134
86
  ):
135
- n_outputs, n_chans, n_times, = deprecated_args(
136
- self,
137
- ('n_classes', 'n_outputs', n_classes, n_outputs),
138
- ('in_channels', 'n_chans', in_channels, n_chans),
139
- ('input_window_samples', 'n_times', input_window_samples, n_times),
140
- )
141
87
  super().__init__(
142
88
  n_outputs=n_outputs,
143
89
  n_chans=n_chans,
@@ -145,88 +91,131 @@ class EEGITNet(EEGModuleMixin, nn.Sequential):
145
91
  n_times=n_times,
146
92
  input_window_seconds=input_window_seconds,
147
93
  sfreq=sfreq,
148
- add_log_softmax=add_log_softmax,
149
94
  )
150
95
  self.mapping = {
151
- 'classification.1.weight': 'final_layer.clf.weight',
152
- 'classification.1.bias': 'final_layer.clf.weight'}
96
+ "classification.1.weight": "final_layer.clf.weight",
97
+ "classification.1.bias": "final_layer.clf.weight",
98
+ }
153
99
 
154
100
  del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq
155
- del n_classes, in_channels, input_window_samples
156
101
 
157
102
  # ======== Handling EEG input ========================
158
103
  self.add_module(
159
- "input_preprocess", nn.Sequential(Ensure4d(),
160
- Rearrange(
161
- "ba ch t 1 -> ba 1 ch t"))
104
+ "input_preprocess",
105
+ nn.Sequential(Ensure4d(), Rearrange("ba ch t 1 -> ba 1 ch t")),
162
106
  )
163
107
  # ======== Inception branches ========================
164
108
  block11 = self._get_inception_branch(
165
- in_channels=self.n_chans, out_channels=2, kernel_length=16
109
+ in_channels=self.n_chans,
110
+ out_channels=n_filters_time,
111
+ kernel_length=kernel_length,
112
+ activation=activation,
166
113
  )
167
114
  block12 = self._get_inception_branch(
168
- in_channels=self.n_chans, out_channels=4, kernel_length=32
115
+ in_channels=self.n_chans,
116
+ out_channels=n_filters_time * 2,
117
+ kernel_length=kernel_length * 2,
118
+ activation=activation,
169
119
  )
170
120
  block13 = self._get_inception_branch(
171
- in_channels=self.n_chans, out_channels=8, kernel_length=64
121
+ in_channels=self.n_chans,
122
+ out_channels=n_filters_time * 4,
123
+ kernel_length=n_filters_time * 4,
124
+ activation=activation,
125
+ )
126
+ self.add_module("inception_block", InceptionBlock((block11, block12, block13)))
127
+ self.pool1 = self.add_module(
128
+ "pooling",
129
+ nn.Sequential(
130
+ nn.AvgPool2d(kernel_size=(1, pool_kernel)), nn.Dropout(drop_prob)
131
+ ),
172
132
  )
173
- self.add_module("inception_block", _InceptionBlock((block11, block12, block13)))
174
- self.pool1 = self.add_module("pooling", nn.Sequential(
175
- nn.AvgPool2d(kernel_size=(1, 4)),
176
- nn.Dropout(drop_prob)))
177
133
  # =========== TC blocks =====================
178
134
  self.add_module(
179
135
  "TC_block1",
180
- _TCBlock(in_ch=14, kernel_length=4, dialation=1, padding=3, drop_prob=drop_prob)
136
+ _TCBlock(
137
+ in_ch=tcn_in_channel,
138
+ kernel_length=tcn_kernel_size,
139
+ dilatation=tcn_dilatation,
140
+ padding=tcn_padding,
141
+ drop_prob=drop_prob,
142
+ activation=activation,
143
+ ),
181
144
  )
182
145
  # ================================
183
146
  self.add_module(
184
147
  "TC_block2",
185
- _TCBlock(in_ch=14, kernel_length=4, dialation=2, padding=6, drop_prob=drop_prob)
148
+ _TCBlock(
149
+ in_ch=tcn_in_channel,
150
+ kernel_length=tcn_kernel_size,
151
+ dilatation=tcn_dilatation * 2,
152
+ padding=tcn_padding * 2,
153
+ drop_prob=drop_prob,
154
+ activation=activation,
155
+ ),
186
156
  )
187
157
  # ================================
188
158
  self.add_module(
189
159
  "TC_block3",
190
- _TCBlock(in_ch=14, kernel_length=4, dialation=4, padding=12, drop_prob=drop_prob)
160
+ _TCBlock(
161
+ in_ch=tcn_in_channel,
162
+ kernel_length=tcn_kernel_size,
163
+ dilatation=tcn_dilatation * 4,
164
+ padding=tcn_padding * 4,
165
+ drop_prob=drop_prob,
166
+ activation=activation,
167
+ ),
191
168
  )
192
169
  # ================================
193
170
  self.add_module(
194
171
  "TC_block4",
195
- _TCBlock(in_ch=14, kernel_length=4, dialation=8, padding=24, drop_prob=drop_prob)
172
+ _TCBlock(
173
+ in_ch=tcn_in_channel,
174
+ kernel_length=tcn_kernel_size,
175
+ dilatation=tcn_dilatation * 8,
176
+ padding=tcn_padding * 8,
177
+ drop_prob=drop_prob,
178
+ activation=activation,
179
+ ),
196
180
  )
197
181
 
198
182
  # ============= Dimensionality reduction ===================
199
- self.add_module("dim_reduction", nn.Sequential(
200
- nn.Conv2d(14, 28, kernel_size=(1, 1)),
201
- nn.BatchNorm2d(28),
202
- nn.ELU(),
203
- nn.AvgPool2d((1, 4)),
204
- nn.Dropout(drop_prob)))
183
+ self.add_module(
184
+ "dim_reduction",
185
+ nn.Sequential(
186
+ nn.Conv2d(tcn_in_channel, tcn_in_channel * 2, kernel_size=(1, 1)),
187
+ nn.BatchNorm2d(tcn_in_channel * 2),
188
+ activation(),
189
+ nn.AvgPool2d((1, tcn_kernel_size)),
190
+ nn.Dropout(drop_prob),
191
+ ),
192
+ )
205
193
  # ============== Classifier ==================
206
194
  # Moved flatten to another layer
207
195
  self.add_module("flatten", nn.Flatten())
208
196
 
209
- # Incorporating classification module and subsequent ones in one final layer
210
- module = nn.Sequential()
211
-
212
- module.add_module("clf",
213
- nn.Linear(int(int(self.n_times / 4) / 4) * 28, self.n_outputs))
197
+ num_features = self.get_output_shape()[-1]
214
198
 
215
- if self.add_log_softmax:
216
- module.add_module("out_fun", nn.LogSoftmax(dim=1))
217
- else:
218
- module.add_module("out_fun", nn.Identity())
219
-
220
- self.add_module("final_layer", module)
199
+ self.add_module("final_layer", nn.Linear(num_features, self.n_outputs))
221
200
 
222
201
  @staticmethod
223
- def _get_inception_branch(in_channels, out_channels, kernel_length, depth_multiplier=1):
202
+ def _get_inception_branch(
203
+ in_channels,
204
+ out_channels,
205
+ kernel_length,
206
+ depth_multiplier=1,
207
+ activation: nn.Module = nn.ELU,
208
+ ):
224
209
  return nn.Sequential(
225
210
  nn.Conv2d(
226
- 1, out_channels, kernel_size=(1, kernel_length), padding="same", bias=False
211
+ 1,
212
+ out_channels,
213
+ kernel_size=(1, kernel_length),
214
+ padding="same",
215
+ bias=False,
227
216
  ),
228
217
  nn.BatchNorm2d(out_channels),
229
- _DepthwiseConv2d(
218
+ DepthwiseConv2d(
230
219
  out_channels,
231
220
  kernel_size=(in_channels, 1),
232
221
  depth_multiplier=depth_multiplier,
@@ -234,4 +223,79 @@ class EEGITNet(EEGModuleMixin, nn.Sequential):
234
223
  padding="valid",
235
224
  ),
236
225
  nn.BatchNorm2d(out_channels),
237
- nn.ELU())
226
+ activation(),
227
+ )
228
+
229
+
230
+ class _TCBlock(nn.Module):
231
+ """
232
+ Temporal Convolutional (TC) block.
233
+
234
+ This module applies two depthwise separable convolutions with dilation and residual
235
+ connections, commonly used in temporal convolutional networks to capture long-range
236
+ dependencies in time-series data.
237
+
238
+ Parameters
239
+ ----------
240
+ in_ch : int
241
+ Number of input channels.
242
+ kernel_length : int
243
+ Length of the convolutional kernels.
244
+ dilatation : int
245
+ Dilatation rate for the convolutions.
246
+ padding : int
247
+ Amount of padding to add to the input.
248
+ drop_prob : float, optional
249
+ Dropout probability. Default is 0.4.
250
+ activation : nn.Module class, optional
251
+ Activation function class to use. Should be a PyTorch activation module class
252
+ like ``nn.ReLU`` or ``nn.ELU``. Default is ``nn.ELU``.
253
+ """
254
+
255
+ def __init__(
256
+ self,
257
+ in_ch,
258
+ kernel_length,
259
+ dilatation,
260
+ padding,
261
+ drop_prob=0.4,
262
+ activation: nn.Module = nn.ELU,
263
+ ):
264
+ super().__init__()
265
+ self.pad = padding
266
+ self.tc1 = nn.Sequential(
267
+ DepthwiseConv2d(
268
+ in_ch,
269
+ kernel_size=(1, kernel_length),
270
+ depth_multiplier=1,
271
+ dilation=(1, dilatation),
272
+ bias=False,
273
+ padding="valid",
274
+ ),
275
+ nn.BatchNorm2d(in_ch),
276
+ activation(),
277
+ nn.Dropout(drop_prob),
278
+ )
279
+
280
+ self.tc2 = nn.Sequential(
281
+ DepthwiseConv2d(
282
+ in_ch,
283
+ kernel_size=(1, kernel_length),
284
+ depth_multiplier=1,
285
+ dilation=(1, dilatation),
286
+ bias=False,
287
+ padding="valid",
288
+ ),
289
+ nn.BatchNorm2d(in_ch),
290
+ activation(),
291
+ nn.Dropout(drop_prob),
292
+ )
293
+
294
+ def forward(self, x):
295
+ residual = x
296
+ paddings = (self.pad, 0, 0, 0, 0, 0, 0, 0)
297
+ x = nn.functional.pad(x, paddings)
298
+ x = self.tc1(x)
299
+ x = nn.functional.pad(x, paddings)
300
+ x = self.tc2(x) + residual
301
+ return x