braindecode-0.8.1-py3-none-any.whl → braindecode-1.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Note: this release of braindecode has been flagged as potentially problematic.
Files changed (108)
  1. braindecode/__init__.py +1 -2
  2. braindecode/augmentation/__init__.py +39 -19
  3. braindecode/augmentation/base.py +25 -28
  4. braindecode/augmentation/functional.py +237 -100
  5. braindecode/augmentation/transforms.py +325 -158
  6. braindecode/classifier.py +26 -24
  7. braindecode/datasets/__init__.py +28 -10
  8. braindecode/datasets/base.py +220 -134
  9. braindecode/datasets/bbci.py +43 -52
  10. braindecode/datasets/bcicomp.py +47 -32
  11. braindecode/datasets/bids.py +245 -0
  12. braindecode/datasets/mne.py +45 -24
  13. braindecode/datasets/moabb.py +87 -27
  14. braindecode/datasets/nmt.py +311 -0
  15. braindecode/datasets/sleep_physio_challe_18.py +412 -0
  16. braindecode/datasets/sleep_physionet.py +43 -26
  17. braindecode/datasets/tuh.py +324 -140
  18. braindecode/datasets/xy.py +27 -12
  19. braindecode/datautil/__init__.py +37 -18
  20. braindecode/datautil/serialization.py +110 -72
  21. braindecode/eegneuralnet.py +63 -47
  22. braindecode/functional/__init__.py +22 -0
  23. braindecode/functional/functions.py +250 -0
  24. braindecode/functional/initialization.py +47 -0
  25. braindecode/models/__init__.py +84 -14
  26. braindecode/models/atcnet.py +193 -164
  27. braindecode/models/attentionbasenet.py +599 -0
  28. braindecode/models/base.py +86 -102
  29. braindecode/models/biot.py +504 -0
  30. braindecode/models/contrawr.py +317 -0
  31. braindecode/models/ctnet.py +536 -0
  32. braindecode/models/deep4.py +116 -77
  33. braindecode/models/deepsleepnet.py +149 -119
  34. braindecode/models/eegconformer.py +112 -173
  35. braindecode/models/eeginception_erp.py +109 -118
  36. braindecode/models/eeginception_mi.py +161 -97
  37. braindecode/models/eegitnet.py +215 -152
  38. braindecode/models/eegminer.py +254 -0
  39. braindecode/models/eegnet.py +228 -161
  40. braindecode/models/eegnex.py +247 -0
  41. braindecode/models/eegresnet.py +234 -152
  42. braindecode/models/eegsimpleconv.py +199 -0
  43. braindecode/models/eegtcnet.py +335 -0
  44. braindecode/models/fbcnet.py +221 -0
  45. braindecode/models/fblightconvnet.py +313 -0
  46. braindecode/models/fbmsnet.py +324 -0
  47. braindecode/models/hybrid.py +52 -71
  48. braindecode/models/ifnet.py +441 -0
  49. braindecode/models/labram.py +1186 -0
  50. braindecode/models/msvtnet.py +375 -0
  51. braindecode/models/sccnet.py +207 -0
  52. braindecode/models/shallow_fbcsp.py +50 -56
  53. braindecode/models/signal_jepa.py +1011 -0
  54. braindecode/models/sinc_shallow.py +337 -0
  55. braindecode/models/sleep_stager_blanco_2020.py +55 -46
  56. braindecode/models/sleep_stager_chambon_2018.py +54 -53
  57. braindecode/models/sleep_stager_eldele_2021.py +247 -141
  58. braindecode/models/sparcnet.py +424 -0
  59. braindecode/models/summary.csv +41 -0
  60. braindecode/models/syncnet.py +232 -0
  61. braindecode/models/tcn.py +158 -88
  62. braindecode/models/tidnet.py +280 -167
  63. braindecode/models/tsinception.py +283 -0
  64. braindecode/models/usleep.py +190 -177
  65. braindecode/models/util.py +109 -145
  66. braindecode/modules/__init__.py +84 -0
  67. braindecode/modules/activation.py +60 -0
  68. braindecode/modules/attention.py +757 -0
  69. braindecode/modules/blocks.py +108 -0
  70. braindecode/modules/convolution.py +274 -0
  71. braindecode/modules/filter.py +628 -0
  72. braindecode/modules/layers.py +131 -0
  73. braindecode/modules/linear.py +49 -0
  74. braindecode/modules/parametrization.py +38 -0
  75. braindecode/modules/stats.py +77 -0
  76. braindecode/modules/util.py +76 -0
  77. braindecode/modules/wrapper.py +73 -0
  78. braindecode/preprocessing/__init__.py +36 -11
  79. braindecode/preprocessing/mne_preprocess.py +13 -7
  80. braindecode/preprocessing/preprocess.py +139 -75
  81. braindecode/preprocessing/windowers.py +576 -187
  82. braindecode/regressor.py +23 -12
  83. braindecode/samplers/__init__.py +16 -8
  84. braindecode/samplers/base.py +146 -32
  85. braindecode/samplers/ssl.py +162 -17
  86. braindecode/training/__init__.py +18 -10
  87. braindecode/training/callbacks.py +2 -4
  88. braindecode/training/losses.py +3 -8
  89. braindecode/training/scoring.py +76 -68
  90. braindecode/util.py +55 -59
  91. braindecode/version.py +1 -1
  92. braindecode/visualization/__init__.py +2 -3
  93. braindecode/visualization/confusion_matrices.py +117 -73
  94. braindecode/visualization/gradients.py +14 -10
  95. {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/METADATA +42 -58
  96. braindecode-1.1.0.dist-info/RECORD +101 -0
  97. {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/WHEEL +1 -1
  98. {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info/licenses}/LICENSE.txt +1 -1
  99. braindecode-1.1.0.dist-info/licenses/NOTICE.txt +20 -0
  100. braindecode/datautil/mne.py +0 -9
  101. braindecode/datautil/preprocess.py +0 -12
  102. braindecode/datautil/windowers.py +0 -6
  103. braindecode/datautil/xy.py +0 -9
  104. braindecode/models/eeginception.py +0 -317
  105. braindecode/models/functions.py +0 -47
  106. braindecode/models/modules.py +0 -358
  107. braindecode-0.8.1.dist-info/RECORD +0 -68
  108. {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/top_level.txt +0 -0
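
The listing above shows several internal modules being re-homed rather than dropped: braindecode/models/modules.py and braindecode/models/functions.py disappear while new braindecode/modules/ and braindecode/functional/ packages are added, and the braindecode/datautil shims (mne.py, preprocess.py, windowers.py, xy.py) are removed. A minimal, hedged migration sketch follows, assuming the public names carry over unchanged into the new subpackages (not verified against the 1.1.0 sources):

    # Hedged sketch: assumes Ensure4d, safe_log and square keep their names in the
    # new subpackages added in 1.1.0 (braindecode.modules, braindecode.functional).
    # Old 0.8.x import paths are shown as comments.

    # from braindecode.models.modules import Ensure4d            # 0.8.x
    # from braindecode.models.functions import safe_log, square  # 0.8.x

    from braindecode.modules import Ensure4d                     # 1.1.0 (assumed)
    from braindecode.functional import safe_log, square          # 1.1.0 (assumed)
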
braindecode/datautil/xy.py (deleted)
@@ -1,9 +0,0 @@
-# Authors: Robin Schirrmeister <robintibor@gmail.com>
-#
-# License: BSD (3-clause)
-
-from warnings import warn
-from ..datasets.xy import *  # noqa: F401,F403
-
-warn('datautil.xy module is deprecated and is now under '
-     'datasets.xy, please use from import braindecode.datasets.xy')
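
The removed shim only re-exported braindecode.datasets.xy and warned about the new location. A hedged usage sketch of the replacement path follows; create_from_X_y and its exact signature are assumed here rather than taken from this diff:

    import numpy as np
    from braindecode.datasets.xy import create_from_X_y  # was braindecode.datautil.xy

    # Hypothetical toy data: 5 trials, 4 channels, 256 samples each.
    X = np.random.randn(5, 4, 256).astype("float32")
    y = np.zeros(5, dtype=int)
    windows_ds = create_from_X_y(
        X, y, drop_last_window=False, sfreq=128.0, ch_names=["C3", "C4", "Cz", "Pz"]
    )
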
braindecode/models/eeginception.py (deleted)
@@ -1,317 +0,0 @@
-# Authors: Bruno Aristimunha <b.aristimunha@gmail.com>
-#          Cedric Rommel <cedric.rommel@inria.fr>
-#
-# License: BSD (3-clause)
-from warnings import warn
-
-from numpy import prod
-
-from torch import nn
-from einops.layers.torch import Rearrange
-from .modules import Ensure4d
-from .eegnet import _glorot_weight_zero_bias
-from .eegitnet import _InceptionBlock, _DepthwiseConv2d
-from .base import EEGModuleMixin, deprecated_args
-
-
-class EEGInception(EEGModuleMixin, nn.Sequential):
-    """ EEG Inception for ERP-based classification
-
-    --> DEPRECATED <--
-    THIS CLASS IS DEPRECATED AND WILL BE REMOVED IN THE RELEASE 0.9 OF
-    BRAINDECODE. PLEASE USE braindecode.models.EEGInceptionERP INSTEAD IN THE
-    FUTURE.
-
-    The code for the paper and this model is also available at [Santamaria2020]_
-    and an adaptation for PyTorch [2]_.
-
-    The model is strongly based on the original InceptionNet for an image. The main goal is
-    to extract features in parallel with different scales. The authors extracted three scales
-    proportional to the window sample size. The network had three parts:
-    1-larger inception block largest, 2-smaller inception block followed by 3-bottleneck
-    for classification.
-
-    One advantage of the EEG-Inception block is that it allows a network
-    to learn simultaneous components of low and high frequency associated with the signal.
-    The winners of BEETL Competition/NeurIps 2021 used parts of the model [beetl]_.
-
-    The model is fully described in [Santamaria2020]_.
-
-    Notes
-    -----
-    This implementation is not guaranteed to be correct, has not been checked
-    by original authors, only reimplemented from the paper based on [2]_.
-
-    Parameters
-    ----------
-    drop_prob : float
-        Dropout rate inside all the network.
-    scales_time: list(int)
-        Windows for inception block, must be a list with proportional values of
-        the input_size_ms.
-        According to the authors: temporal scale (ms) of the convolutions
-        on each Inception module.
-        This parameter determines the kernel sizes of the filters.
-    n_filters : int
-        Initial number of convolutional filters. Set to 8 in [Santamaria2020]_.
-    activation: nn.Module
-        Activation function, default: ELU activation.
-    batch_norm_alpha: float
-        Momentum for BatchNorm2d.
-    depth_multiplier: int
-        Depth multiplier for the depthwise convolution.
-    pooling_sizes: list(int)
-        Pooling sizes for the inception block.
-    in_channels : int
-        Alias for n_chans.
-    n_classes : int
-        Alias for n_outputs.
-    input_window_samples : int
-        Alias for input_window_seconds.
-
-    References
-    ----------
-    .. [Santamaria2020] Santamaria-Vazquez, E., Martinez-Cagigal, V.,
-       Vaquerizo-Villar, F., & Hornero, R. (2020).
-       EEG-inception: A novel deep convolutional neural network for assistive
-       ERP-based brain-computer interfaces.
-       IEEE Transactions on Neural Systems and Rehabilitation Engineering , v. 28.
-       Online: http://dx.doi.org/10.1109/TNSRE.2020.3048106
-    .. [2] Grifcc. Implementation of the EEGInception in torch (2022).
-       Online: https://github.com/Grifcc/EEG/tree/90e412a407c5242dfc953d5ffb490bdb32faf022
-    .. [beetl]_ Wei, X., Faisal, A.A., Grosse-Wentrup, M., Gramfort, A., Chevallier, S.,
-       Jayaram, V., Jeunet, C., Bakas, S., Ludwig, S., Barmpas, K., Bahri, M., Panagakis,
-       Y., Laskaris, N., Adamos, D.A., Zafeiriou, S., Duong, W.C., Gordon, S.M.,
-       Lawhern, V.J., Śliwowski, M., Rouanne, V. &amp; Tempczyk, P.. (2022).
-       2021 BEETL Competition: Advancing Transfer Learning for Subject Independence &amp;
-       Heterogeneous EEG Data Sets. <i>Proceedings of the NeurIPS 2021 Competitions and
-       Demonstrations Track</i>, in <i>Proceedings of Machine Learning Research</i>
-       176:205-219 Available from https://proceedings.mlr.press/v176/wei22a.html.
-
-    """
-
-    def __init__(
-        self,
-        n_chans=None,
-        n_outputs=None,
-        n_times=1000,
-        sfreq=128,
-        drop_prob=0.5,
-        scales_samples_s=(0.5, 0.25, 0.125),
-        n_filters=8,
-        activation=nn.ELU(),
-        batch_norm_alpha=0.01,
-        depth_multiplier=2,
-        pooling_sizes=(4, 2, 2, 2),
-        chs_info=None,
-        input_window_seconds=None,
-        in_channels=None,
-        n_classes=None,
-        input_window_samples=None,
-        add_log_softmax=True,
-    ):
-        n_chans, n_outputs, n_times, = deprecated_args(
-            self,
-            ('in_channels', 'n_chans', in_channels, n_chans),
-            ('n_classes', 'n_outputs', n_classes, n_outputs),
-            ('input_window_samples', 'n_times', input_window_samples, n_times),
-        )
-        super().__init__(
-            n_outputs=n_outputs,
-            n_chans=n_chans,
-            chs_info=chs_info,
-            n_times=n_times,
-            input_window_seconds=input_window_seconds,
-            sfreq=sfreq,
-            add_log_softmax=add_log_softmax,
-        )
-        del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq
-        del in_channels, n_classes, input_window_samples
-        warn(
-            "The class EEGInception is deprecated and will be removed in the "
-            "release 0.9 of braindecode. Please use "
-            "braindecode.models.EEGInceptionERP instead in the future.",
-            DeprecationWarning
-        )
-
-        self.drop_prob = drop_prob
-        self.n_filters = n_filters
-        self.scales_samples_s = scales_samples_s
-        self.scales_samples = tuple(
-            int(size_s * self.sfreq) for size_s in self.scales_samples_s)
-        self.activation = activation
-        self.alpha_momentum = batch_norm_alpha
-        self.depth_multiplier = depth_multiplier
-        self.pooling_sizes = pooling_sizes
-
-        self.mapping = {
-            'classification.1.weight': 'final_layer.fc.weight',
-            'classification.1.bias': 'final_layer.fc.bias'}
-
-        self.add_module("ensuredims", Ensure4d())
-
-        self.add_module("dimshuffle", Rearrange("batch C T 1 -> batch 1 C T"))
-
-        # ======== Inception branches ========================
-        block11 = self._get_inception_branch_1(
-            in_channels=self.n_chans,
-            out_channels=self.n_filters,
-            kernel_length=self.scales_samples[0],
-            alpha_momentum=self.alpha_momentum,
-            activation=self.activation,
-            drop_prob=self.drop_prob,
-            depth_multiplier=self.depth_multiplier,
-        )
-        block12 = self._get_inception_branch_1(
-            in_channels=self.n_chans,
-            out_channels=self.n_filters,
-            kernel_length=self.scales_samples[1],
-            alpha_momentum=self.alpha_momentum,
-            activation=self.activation,
-            drop_prob=self.drop_prob,
-            depth_multiplier=self.depth_multiplier,
-        )
-        block13 = self._get_inception_branch_1(
-            in_channels=self.n_chans,
-            out_channels=self.n_filters,
-            kernel_length=self.scales_samples[2],
-            alpha_momentum=self.alpha_momentum,
-            activation=self.activation,
-            drop_prob=self.drop_prob,
-            depth_multiplier=self.depth_multiplier,
-        )
-
-        self.add_module("inception_block_1", _InceptionBlock((block11, block12, block13)))
-
-        self.add_module("avg_pool_1", nn.AvgPool2d((1, self.pooling_sizes[0])))
-
-        # ======== Inception branches ========================
-        n_concat_filters = len(self.scales_samples) * self.n_filters
-        n_concat_dw_filters = n_concat_filters * self.depth_multiplier
-        block21 = self._get_inception_branch_2(
-            in_channels=n_concat_dw_filters,
-            out_channels=self.n_filters,
-            kernel_length=self.scales_samples[0] // 4,
-            alpha_momentum=self.alpha_momentum,
-            activation=self.activation,
-            drop_prob=self.drop_prob
-        )
-        block22 = self._get_inception_branch_2(
-            in_channels=n_concat_dw_filters,
-            out_channels=self.n_filters,
-            kernel_length=self.scales_samples[1] // 4,
-            alpha_momentum=self.alpha_momentum,
-            activation=self.activation,
-            drop_prob=self.drop_prob
-        )
-        block23 = self._get_inception_branch_2(
-            in_channels=n_concat_dw_filters,
-            out_channels=self.n_filters,
-            kernel_length=self.scales_samples[2] // 4,
-            alpha_momentum=self.alpha_momentum,
-            activation=self.activation,
-            drop_prob=self.drop_prob
-        )
-
-        self.add_module(
-            "inception_block_2", _InceptionBlock((block21, block22, block23)))
-
-        self.add_module("avg_pool_2", nn.AvgPool2d((1, self.pooling_sizes[1])))
-
-        self.add_module("final_block", nn.Sequential(
-            nn.Conv2d(
-                n_concat_filters,
-                n_concat_filters // 2,
-                (1, 8),
-                padding="same",
-                bias=False
-            ),
-            nn.BatchNorm2d(n_concat_filters // 2,
-                           momentum=self.alpha_momentum),
-            activation,
-            nn.Dropout(self.drop_prob),
-            nn.AvgPool2d((1, self.pooling_sizes[2])),
-
-            nn.Conv2d(
-                n_concat_filters // 2,
-                n_concat_filters // 4,
-                (1, 4),
-                padding="same",
-                bias=False
-            ),
-            nn.BatchNorm2d(n_concat_filters // 4,
-                           momentum=self.alpha_momentum),
-            activation,
-            nn.Dropout(self.drop_prob),
-            nn.AvgPool2d((1, self.pooling_sizes[3])),
-        ))
-
-        spatial_dim_last_layer = (
-            self.n_times // prod(self.pooling_sizes))
-        n_channels_last_layer = self.n_filters * len(self.scales_samples) // 4
-
-        self.add_module("flat", nn.Flatten())
-
-        module = nn.Sequential()
-
-        module.add_module("fc",
-                          nn.Linear(
-                              spatial_dim_last_layer * n_channels_last_layer,
-                              self.n_outputs
-                          ), )
-
-        if self.add_log_softmax:
-            module.add_module("logsoftmax", nn.LogSoftmax(dim=1))
-        else:
-            module.add_module("identity", nn.Identity())
-
-        # The conv_classifier will be the final_layer and the other ones will be incorporated
-        self.add_module("final_layer", module)
-
-        _glorot_weight_zero_bias(self)
-
-    @staticmethod
-    def _get_inception_branch_1(in_channels, out_channels, kernel_length,
-                                alpha_momentum, drop_prob, activation,
-                                depth_multiplier):
-        return nn.Sequential(
-            nn.Conv2d(
-                1,
-                out_channels,
-                kernel_size=(1, kernel_length),
-                padding="same",
-                bias=True
-            ),
-            nn.BatchNorm2d(out_channels, momentum=alpha_momentum),
-            activation,
-            nn.Dropout(drop_prob),
-            _DepthwiseConv2d(
-                out_channels,
-                kernel_size=(in_channels, 1),
-                depth_multiplier=depth_multiplier,
-                bias=False,
-                padding="valid",
-            ),
-            nn.BatchNorm2d(
-                depth_multiplier * out_channels,
-                momentum=alpha_momentum
-            ),
-            activation,
-            nn.Dropout(drop_prob),
-        )
-
-    @staticmethod
-    def _get_inception_branch_2(in_channels, out_channels, kernel_length,
-                                alpha_momentum, drop_prob, activation):
-        return nn.Sequential(
-            nn.Conv2d(
-                in_channels,
-                out_channels,
-                kernel_size=(1, kernel_length),
-                padding="same",
-                bias=False
-            ),
-            nn.BatchNorm2d(out_channels, momentum=alpha_momentum),
-            activation,
-            nn.Dropout(drop_prob),
-        )
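
The deprecation warning in the removed class points to braindecode.models.EEGInceptionERP. A minimal hedged sketch of the replacement follows, assuming it accepts the same EEGModuleMixin-style arguments (n_chans, n_outputs, n_times, sfreq) used above:

    import torch
    from braindecode.models import EEGInceptionERP

    # Same defaults as the removed class: 1000 samples at 128 Hz.
    model = EEGInceptionERP(n_chans=8, n_outputs=2, n_times=1000, sfreq=128)
    x = torch.randn(4, 8, 1000)  # (batch, n_chans, n_times)
    logits = model(x)            # expected shape: (4, 2)
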
braindecode/models/functions.py (deleted)
@@ -1,47 +0,0 @@
-# Authors: Robin Schirrmeister <robintibor@gmail.com>
-#
-# License: BSD (3-clause)
-
-import torch
-
-
-def square(x):
-    return x * x
-
-
-def safe_log(x, eps=1e-6):
-    """ Prevents :math:`log(0)` by using :math:`log(max(x, eps))`."""
-    return torch.log(torch.clamp(x, min=eps))
-
-
-def identity(x):
-    return x
-
-
-def squeeze_final_output(x):
-    """Removes empty dimension at end and potentially removes empty time
-    dimension. It does not just use squeeze as we never want to remove
-    first dimension.
-
-    Returns
-    -------
-    x: torch.Tensor
-        squeezed tensor
-    """
-
-    assert x.size()[3] == 1
-    x = x[:, :, :, 0]
-    if x.size()[2] == 1:
-        x = x[:, :, 0]
-    return x
-
-
-def transpose_time_to_spat(x):
-    """Swap time and spatial dimensions.
-
-    Returns
-    -------
-    x: torch.Tensor
-        tensor in which last and first dimensions are swapped
-    """
-    return x.permute(0, 3, 2, 1)
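
Per the file listing, equivalents of these helpers appear to move to braindecode/functional/functions.py in 1.1.0 (that import path is an assumption, not shown in this hunk). The self-contained shape demo below reproduces locally what squeeze_final_output and transpose_time_to_spat do:

    import torch

    # squeeze_final_output: drop the trailing singleton dim, and the time dim if it is 1.
    x = torch.randn(2, 40, 1, 1)        # (batch, filters, time=1, 1)
    y = x[:, :, :, 0]
    if y.size()[2] == 1:
        y = y[:, :, 0]
    print(y.shape)                      # torch.Size([2, 40])

    # transpose_time_to_spat: permute (batch, 1, chans, time) -> (batch, time, chans, 1).
    z = torch.randn(2, 1, 22, 1000)
    print(z.permute(0, 3, 2, 1).shape)  # torch.Size([2, 1000, 22, 1])
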