braindecode 0.8.1__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of braindecode might be problematic. See the release's page on the package registry for more details.

Files changed (108) hide show
  1. braindecode/__init__.py +1 -2
  2. braindecode/augmentation/__init__.py +39 -19
  3. braindecode/augmentation/base.py +25 -28
  4. braindecode/augmentation/functional.py +237 -100
  5. braindecode/augmentation/transforms.py +326 -158
  6. braindecode/classifier.py +26 -24
  7. braindecode/datasets/__init__.py +28 -10
  8. braindecode/datasets/base.py +220 -134
  9. braindecode/datasets/bbci.py +43 -52
  10. braindecode/datasets/bcicomp.py +47 -32
  11. braindecode/datasets/bids.py +245 -0
  12. braindecode/datasets/mne.py +45 -24
  13. braindecode/datasets/moabb.py +87 -27
  14. braindecode/datasets/nmt.py +311 -0
  15. braindecode/datasets/sleep_physio_challe_18.py +412 -0
  16. braindecode/datasets/sleep_physionet.py +43 -26
  17. braindecode/datasets/tuh.py +324 -140
  18. braindecode/datasets/xy.py +27 -12
  19. braindecode/datautil/__init__.py +34 -18
  20. braindecode/datautil/serialization.py +98 -71
  21. braindecode/eegneuralnet.py +63 -47
  22. braindecode/functional/__init__.py +10 -0
  23. braindecode/functional/functions.py +251 -0
  24. braindecode/functional/initialization.py +47 -0
  25. braindecode/models/__init__.py +36 -14
  26. braindecode/models/atcnet.py +153 -159
  27. braindecode/models/attentionbasenet.py +550 -0
  28. braindecode/models/base.py +86 -102
  29. braindecode/models/biot.py +483 -0
  30. braindecode/models/contrawr.py +296 -0
  31. braindecode/models/ctnet.py +450 -0
  32. braindecode/models/deep4.py +64 -75
  33. braindecode/models/deepsleepnet.py +149 -119
  34. braindecode/models/eegconformer.py +111 -171
  35. braindecode/models/eeginception_erp.py +109 -118
  36. braindecode/models/eeginception_mi.py +155 -97
  37. braindecode/models/eegitnet.py +215 -151
  38. braindecode/models/eegminer.py +255 -0
  39. braindecode/models/eegnet.py +229 -161
  40. braindecode/models/eegnex.py +247 -0
  41. braindecode/models/eegresnet.py +234 -152
  42. braindecode/models/eegsimpleconv.py +199 -0
  43. braindecode/models/eegtcnet.py +335 -0
  44. braindecode/models/fbcnet.py +221 -0
  45. braindecode/models/fblightconvnet.py +313 -0
  46. braindecode/models/fbmsnet.py +325 -0
  47. braindecode/models/hybrid.py +52 -71
  48. braindecode/models/ifnet.py +441 -0
  49. braindecode/models/labram.py +1166 -0
  50. braindecode/models/msvtnet.py +375 -0
  51. braindecode/models/sccnet.py +182 -0
  52. braindecode/models/shallow_fbcsp.py +50 -56
  53. braindecode/models/signal_jepa.py +1012 -0
  54. braindecode/models/sinc_shallow.py +337 -0
  55. braindecode/models/sleep_stager_blanco_2020.py +55 -46
  56. braindecode/models/sleep_stager_chambon_2018.py +54 -53
  57. braindecode/models/sleep_stager_eldele_2021.py +248 -141
  58. braindecode/models/sparcnet.py +378 -0
  59. braindecode/models/summary.csv +41 -0
  60. braindecode/models/syncnet.py +232 -0
  61. braindecode/models/tcn.py +158 -88
  62. braindecode/models/tidnet.py +280 -167
  63. braindecode/models/tsinception.py +258 -0
  64. braindecode/models/usleep.py +190 -177
  65. braindecode/models/util.py +109 -141
  66. braindecode/modules/__init__.py +38 -0
  67. braindecode/modules/activation.py +60 -0
  68. braindecode/modules/attention.py +757 -0
  69. braindecode/modules/blocks.py +108 -0
  70. braindecode/modules/convolution.py +274 -0
  71. braindecode/modules/filter.py +632 -0
  72. braindecode/modules/layers.py +133 -0
  73. braindecode/modules/linear.py +50 -0
  74. braindecode/modules/parametrization.py +38 -0
  75. braindecode/modules/stats.py +77 -0
  76. braindecode/modules/util.py +77 -0
  77. braindecode/modules/wrapper.py +75 -0
  78. braindecode/preprocessing/__init__.py +36 -11
  79. braindecode/preprocessing/mne_preprocess.py +13 -7
  80. braindecode/preprocessing/preprocess.py +139 -75
  81. braindecode/preprocessing/windowers.py +576 -187
  82. braindecode/regressor.py +23 -12
  83. braindecode/samplers/__init__.py +16 -8
  84. braindecode/samplers/base.py +148 -32
  85. braindecode/samplers/ssl.py +162 -17
  86. braindecode/training/__init__.py +18 -10
  87. braindecode/training/callbacks.py +2 -4
  88. braindecode/training/losses.py +3 -8
  89. braindecode/training/scoring.py +76 -68
  90. braindecode/util.py +55 -59
  91. braindecode/version.py +1 -1
  92. braindecode/visualization/__init__.py +2 -3
  93. braindecode/visualization/confusion_matrices.py +117 -73
  94. braindecode/visualization/gradients.py +14 -10
  95. {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info}/METADATA +39 -55
  96. braindecode-1.0.0.dist-info/RECORD +101 -0
  97. {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info}/WHEEL +1 -1
  98. {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info/licenses}/LICENSE.txt +1 -1
  99. braindecode-1.0.0.dist-info/licenses/NOTICE.txt +20 -0
  100. braindecode/datautil/mne.py +0 -9
  101. braindecode/datautil/preprocess.py +0 -12
  102. braindecode/datautil/windowers.py +0 -6
  103. braindecode/datautil/xy.py +0 -9
  104. braindecode/models/eeginception.py +0 -317
  105. braindecode/models/functions.py +0 -47
  106. braindecode/models/modules.py +0 -358
  107. braindecode-0.8.1.dist-info/RECORD +0 -68
  108. {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info}/top_level.txt +0 -0
@@ -2,21 +2,24 @@
2
2
  # Cedric Rommel <cedric.rommel@inria.fr>
3
3
  #
4
4
  # License: BSD (3-clause)
5
- from numpy import prod
5
+ import math
6
6
 
7
- from torch import nn
8
7
  from einops.layers.torch import Rearrange
8
+ from torch import nn
9
9
 
10
- from .modules import Ensure4d
11
- from .eegnet import _glorot_weight_zero_bias
12
- from .eegitnet import _InceptionBlock, _DepthwiseConv2d
13
- from .base import EEGModuleMixin, deprecated_args
10
+ from braindecode.functional import glorot_weight_zero_bias
11
+ from braindecode.models.base import EEGModuleMixin
12
+ from braindecode.modules import DepthwiseConv2d, Ensure4d, InceptionBlock
14
13
 
15
14
 
16
15
  class EEGInceptionERP(EEGModuleMixin, nn.Sequential):
17
- """EEG Inception for ERP-based classification
16
+ """EEG Inception for ERP-based from Santamaria-Vazquez et al (2020) [santamaria2020]_.
17
+
18
+ .. figure:: https://braindecode.org/dev/_static/model/eeginceptionerp.jpg
19
+ :align: center
20
+ :alt: EEGInceptionERP Architecture
18
21
 
19
- The code for the paper and this model is also available at [Santamaria2020]_
22
+ The code for the paper and this model is also available at [santamaria2020]_
20
23
  and an adaptation for PyTorch [2]_.
21
24
 
22
25
  The model is strongly based on the original InceptionNet for an image. The main goal is
@@ -27,9 +30,10 @@ class EEGInceptionERP(EEGModuleMixin, nn.Sequential):
27
30
 
28
31
  One advantage of the EEG-Inception block is that it allows a network
29
32
  to learn simultaneous components of low and high frequency associated with the signal.
30
- The winners of BEETL Competition/NeurIps 2021 used parts of the model [beetl]_.
33
+ The winners of BEETL Competition/NeurIps 2021 used parts of the
34
+ model [beetl]_.
31
35
 
32
- The model is fully described in [Santamaria2020]_.
36
+ The model is fully described in [santamaria2020]_.
33
37
 
34
38
  Notes
35
39
  -----
@@ -40,85 +44,70 @@ class EEGInceptionERP(EEGModuleMixin, nn.Sequential):
40
44
  ----------
41
45
  n_times : int, optional
42
46
  Size of the input, in number of samples. Set to 128 (1s) as in
43
- [Santamaria2020]_.
47
+ [santamaria2020]_.
44
48
  sfreq : float, optional
45
- EEG sampling frequency. Defaults to 128 as in [Santamaria2020]_.
49
+ EEG sampling frequency. Defaults to 128 as in [santamaria2020]_.
46
50
  drop_prob : float, optional
47
51
  Dropout rate inside all the network. Defaults to 0.5 as in
48
- [Santamaria2020]_.
52
+ [santamaria2020]_.
49
53
  scales_samples_s: list(float), optional
50
54
  Windows for inception block. Temporal scale (s) of the convolutions on
51
55
  each Inception module. This parameter determines the kernel sizes of
52
56
  the filters. Defaults to 0.5, 0.25, 0.125 seconds, as in
53
- [Santamaria2020]_.
57
+ [santamaria2020]_.
54
58
  n_filters : int, optional
55
59
  Initial number of convolutional filters. Defaults to 8 as in
56
- [Santamaria2020]_.
60
+ [santamaria2020]_.
57
61
  activation: nn.Module, optional
58
62
  Activation function. Defaults to ELU activation as in
59
- [Santamaria2020]_.
63
+ [santamaria2020]_.
60
64
  batch_norm_alpha: float, optional
61
65
  Momentum for BatchNorm2d. Defaults to 0.01.
62
66
  depth_multiplier: int, optional
63
67
  Depth multiplier for the depthwise convolution. Defaults to 2 as in
64
- [Santamaria2020]_.
68
+ [santamaria2020]_.
65
69
  pooling_sizes: list(int), optional
66
70
  Pooling sizes for the inception blocks. Defaults to 4, 2, 2 and 2, as
67
- in [Santamaria2020]_.
68
- in_channels : int
69
- Alias for n_chans.
70
- n_classes : int
71
- Alias for n_outputs.
72
- input_window_samples : int
73
- Alias for n_times.
71
+ in [santamaria2020]_.
72
+
74
73
 
75
74
  References
76
75
  ----------
77
- .. [Santamaria2020] Santamaria-Vazquez, E., Martinez-Cagigal, V.,
76
+ .. [santamaria2020] Santamaria-Vazquez, E., Martinez-Cagigal, V.,
78
77
  Vaquerizo-Villar, F., & Hornero, R. (2020).
79
78
  EEG-inception: A novel deep convolutional neural network for assistive
80
79
  ERP-based brain-computer interfaces.
81
80
  IEEE Transactions on Neural Systems and Rehabilitation Engineering , v. 28.
82
81
  Online: http://dx.doi.org/10.1109/TNSRE.2020.3048106
83
- .. [2] Grifcc. Implementation of the EEGInception in torch (2022).
84
- Online: https://github.com/Grifcc/EEG/tree/90e412a407c5242dfc953d5ffb490bdb32faf022
85
- .. [beetl]_ Wei, X., Faisal, A.A., Grosse-Wentrup, M., Gramfort, A., Chevallier, S.,
82
+ .. [2] Grifcc. Implementation of the EEGInception in torch (2022).
83
+ Online: https://github.com/Grifcc/EEG/
84
+ .. [beetl] Wei, X., Faisal, A.A., Grosse-Wentrup, M., Gramfort, A., Chevallier, S.,
86
85
  Jayaram, V., Jeunet, C., Bakas, S., Ludwig, S., Barmpas, K., Bahri, M., Panagakis,
87
86
  Y., Laskaris, N., Adamos, D.A., Zafeiriou, S., Duong, W.C., Gordon, S.M.,
88
- Lawhern, V.J., Śliwowski, M., Rouanne, V. &amp; Tempczyk, P.. (2022).
87
+ Lawhern, V.J., Śliwowski, M., Rouanne, V. &amp; Tempczyk, P. (2022).
89
88
  2021 BEETL Competition: Advancing Transfer Learning for Subject Independence &amp;
90
- Heterogeneous EEG Data Sets. <i>Proceedings of the NeurIPS 2021 Competitions and
91
- Demonstrations Track</i>, in <i>Proceedings of Machine Learning Research</i>
89
+ Heterogeneous EEG Data Sets. Proceedings of the NeurIPS 2021 Competitions and
90
+ Demonstrations Track, in Proceedings of Machine Learning Research
92
91
  176:205-219 Available from https://proceedings.mlr.press/v176/wei22a.html.
93
92
 
94
93
  """
95
94
 
96
95
  def __init__(
97
- self,
98
- n_chans=None,
99
- n_outputs=None,
100
- n_times=1000,
101
- sfreq=128,
102
- drop_prob=0.5,
103
- scales_samples_s=(0.5, 0.25, 0.125),
104
- n_filters=8,
105
- activation=nn.ELU(),
106
- batch_norm_alpha=0.01,
107
- depth_multiplier=2,
108
- pooling_sizes=(4, 2, 2, 2),
109
- chs_info=None,
110
- input_window_seconds=None,
111
- in_channels=None,
112
- n_classes=None,
113
- input_window_samples=None,
114
- add_log_softmax=True,
96
+ self,
97
+ n_chans=None,
98
+ n_outputs=None,
99
+ n_times=1000,
100
+ sfreq=128,
101
+ drop_prob=0.5,
102
+ scales_samples_s=(0.5, 0.25, 0.125),
103
+ n_filters=8,
104
+ activation: nn.Module = nn.ELU,
105
+ batch_norm_alpha=0.01,
106
+ depth_multiplier=2,
107
+ pooling_sizes=(4, 2, 2, 2),
108
+ chs_info=None,
109
+ input_window_seconds=None,
115
110
  ):
116
- n_chans, n_outputs, n_times, = deprecated_args(
117
- self,
118
- ('in_channels', 'n_chans', in_channels, n_chans),
119
- ('n_classes', 'n_outputs', n_classes, n_outputs),
120
- ('input_window_samples', 'n_times', input_window_samples, n_times),
121
- )
122
111
  super().__init__(
123
112
  n_outputs=n_outputs,
124
113
  n_chans=n_chans,
@@ -126,23 +115,23 @@ class EEGInceptionERP(EEGModuleMixin, nn.Sequential):
126
115
  n_times=n_times,
127
116
  input_window_seconds=input_window_seconds,
128
117
  sfreq=sfreq,
129
- add_log_softmax=add_log_softmax
130
118
  )
131
119
  del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq
132
- del in_channels, n_classes, input_window_samples
133
120
  self.drop_prob = drop_prob
134
121
  self.n_filters = n_filters
135
122
  self.scales_samples_s = scales_samples_s
136
123
  self.scales_samples = tuple(
137
- int(size_s * self.sfreq) for size_s in self.scales_samples_s)
124
+ int(size_s * self.sfreq) for size_s in self.scales_samples_s
125
+ )
138
126
  self.activation = activation
139
127
  self.alpha_momentum = batch_norm_alpha
140
128
  self.depth_multiplier = depth_multiplier
141
129
  self.pooling_sizes = pooling_sizes
142
130
 
143
131
  self.mapping = {
144
- 'classification.1.weight': 'final_layer.fc.weight',
145
- 'classification.1.bias': 'final_layer.fc.bias'}
132
+ "classification.1.weight": "final_layer.fc.weight",
133
+ "classification.1.bias": "final_layer.fc.bias",
134
+ }
146
135
 
147
136
  self.add_module("ensuredims", Ensure4d())
148
137
 
@@ -177,7 +166,9 @@ class EEGInceptionERP(EEGModuleMixin, nn.Sequential):
177
166
  depth_multiplier=self.depth_multiplier,
178
167
  )
179
168
 
180
- self.add_module("inception_block_1", _InceptionBlock((block11, block12, block13)))
169
+ self.add_module(
170
+ "inception_block_1", InceptionBlock((block11, block12, block13))
171
+ )
181
172
 
182
173
  self.add_module("avg_pool_1", nn.AvgPool2d((1, self.pooling_sizes[0])))
183
174
 
@@ -190,7 +181,7 @@ class EEGInceptionERP(EEGModuleMixin, nn.Sequential):
190
181
  kernel_length=self.scales_samples[0] // 4,
191
182
  alpha_momentum=self.alpha_momentum,
192
183
  activation=self.activation,
193
- drop_prob=self.drop_prob
184
+ drop_prob=self.drop_prob,
194
185
  )
195
186
  block22 = self._get_inception_branch_2(
196
187
  in_channels=n_concat_dw_filters,
@@ -198,7 +189,7 @@ class EEGInceptionERP(EEGModuleMixin, nn.Sequential):
198
189
  kernel_length=self.scales_samples[1] // 4,
199
190
  alpha_momentum=self.alpha_momentum,
200
191
  activation=self.activation,
201
- drop_prob=self.drop_prob
192
+ drop_prob=self.drop_prob,
202
193
  )
203
194
  block23 = self._get_inception_branch_2(
204
195
  in_channels=n_concat_dw_filters,
@@ -206,44 +197,44 @@ class EEGInceptionERP(EEGModuleMixin, nn.Sequential):
206
197
  kernel_length=self.scales_samples[2] // 4,
207
198
  alpha_momentum=self.alpha_momentum,
208
199
  activation=self.activation,
209
- drop_prob=self.drop_prob
200
+ drop_prob=self.drop_prob,
210
201
  )
211
202
 
212
203
  self.add_module(
213
- "inception_block_2", _InceptionBlock((block21, block22, block23)))
204
+ "inception_block_2", InceptionBlock((block21, block22, block23))
205
+ )
214
206
 
215
207
  self.add_module("avg_pool_2", nn.AvgPool2d((1, self.pooling_sizes[1])))
216
208
 
217
- self.add_module("final_block", nn.Sequential(
218
- nn.Conv2d(
219
- n_concat_filters,
220
- n_concat_filters // 2,
221
- (1, 8),
222
- padding="same",
223
- bias=False
209
+ self.add_module(
210
+ "final_block",
211
+ nn.Sequential(
212
+ nn.Conv2d(
213
+ n_concat_filters,
214
+ n_concat_filters // 2,
215
+ (1, 8),
216
+ padding="same",
217
+ bias=False,
218
+ ),
219
+ nn.BatchNorm2d(n_concat_filters // 2, momentum=self.alpha_momentum),
220
+ activation(),
221
+ nn.Dropout(self.drop_prob),
222
+ nn.AvgPool2d((1, self.pooling_sizes[2])),
223
+ nn.Conv2d(
224
+ n_concat_filters // 2,
225
+ n_concat_filters // 4,
226
+ (1, 4),
227
+ padding="same",
228
+ bias=False,
229
+ ),
230
+ nn.BatchNorm2d(n_concat_filters // 4, momentum=self.alpha_momentum),
231
+ activation(),
232
+ nn.Dropout(self.drop_prob),
233
+ nn.AvgPool2d((1, self.pooling_sizes[3])),
224
234
  ),
225
- nn.BatchNorm2d(n_concat_filters // 2,
226
- momentum=self.alpha_momentum),
227
- activation,
228
- nn.Dropout(self.drop_prob),
229
- nn.AvgPool2d((1, self.pooling_sizes[2])),
235
+ )
230
236
 
231
- nn.Conv2d(
232
- n_concat_filters // 2,
233
- n_concat_filters // 4,
234
- (1, 4),
235
- padding="same",
236
- bias=False
237
- ),
238
- nn.BatchNorm2d(n_concat_filters // 4,
239
- momentum=self.alpha_momentum),
240
- activation,
241
- nn.Dropout(self.drop_prob),
242
- nn.AvgPool2d((1, self.pooling_sizes[3])),
243
- ))
244
-
245
- spatial_dim_last_layer = (
246
- self.n_times // prod(self.pooling_sizes))
237
+ spatial_dim_last_layer = self.n_times // math.prod(self.pooling_sizes)
247
238
  n_channels_last_layer = self.n_filters * len(self.scales_samples) // 4
248
239
 
249
240
  self.add_module("flat", nn.Flatten())
@@ -251,63 +242,63 @@ class EEGInceptionERP(EEGModuleMixin, nn.Sequential):
251
242
  # Incorporating classification module and subsequent ones in one final layer
252
243
  module = nn.Sequential()
253
244
 
254
- module.add_module("fc",
255
- nn.Linear(
256
- spatial_dim_last_layer * n_channels_last_layer,
257
- self.n_outputs
258
- ), )
245
+ module.add_module(
246
+ "fc",
247
+ nn.Linear(spatial_dim_last_layer * n_channels_last_layer, self.n_outputs),
248
+ )
259
249
 
260
- if self.add_log_softmax:
261
- module.add_module("logsoftmax", nn.LogSoftmax(dim=1))
262
- else:
263
- module.add_module("identity", nn.Identity())
250
+ module.add_module("identity", nn.Identity())
264
251
 
265
252
  self.add_module("final_layer", module)
266
253
 
267
- _glorot_weight_zero_bias(self)
254
+ glorot_weight_zero_bias(self)
268
255
 
269
256
  @staticmethod
270
- def _get_inception_branch_1(in_channels, out_channels, kernel_length,
271
- alpha_momentum, drop_prob, activation,
272
- depth_multiplier):
257
+ def _get_inception_branch_1(
258
+ in_channels,
259
+ out_channels,
260
+ kernel_length,
261
+ alpha_momentum,
262
+ drop_prob,
263
+ activation,
264
+ depth_multiplier,
265
+ ):
273
266
  return nn.Sequential(
274
267
  nn.Conv2d(
275
268
  1,
276
269
  out_channels,
277
270
  kernel_size=(1, kernel_length),
278
271
  padding="same",
279
- bias=True
272
+ bias=True,
280
273
  ),
281
274
  nn.BatchNorm2d(out_channels, momentum=alpha_momentum),
282
- activation,
275
+ activation(),
283
276
  nn.Dropout(drop_prob),
284
- _DepthwiseConv2d(
277
+ DepthwiseConv2d(
285
278
  out_channels,
286
279
  kernel_size=(in_channels, 1),
287
280
  depth_multiplier=depth_multiplier,
288
281
  bias=False,
289
282
  padding="valid",
290
283
  ),
291
- nn.BatchNorm2d(
292
- depth_multiplier * out_channels,
293
- momentum=alpha_momentum
294
- ),
295
- activation,
284
+ nn.BatchNorm2d(depth_multiplier * out_channels, momentum=alpha_momentum),
285
+ activation(),
296
286
  nn.Dropout(drop_prob),
297
287
  )
298
288
 
299
289
  @staticmethod
300
- def _get_inception_branch_2(in_channels, out_channels, kernel_length,
301
- alpha_momentum, drop_prob, activation):
290
+ def _get_inception_branch_2(
291
+ in_channels, out_channels, kernel_length, alpha_momentum, drop_prob, activation
292
+ ):
302
293
  return nn.Sequential(
303
294
  nn.Conv2d(
304
295
  in_channels,
305
296
  out_channels,
306
297
  kernel_size=(1, kernel_length),
307
298
  padding="same",
308
- bias=False
299
+ bias=False,
309
300
  ),
310
301
  nn.BatchNorm2d(out_channels, momentum=alpha_momentum),
311
- activation,
302
+ activation(),
312
303
  nn.Dropout(drop_prob),
313
304
  )