braindecode 1.3.0.dev177069446__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (124)
  1. braindecode/__init__.py +9 -0
  2. braindecode/augmentation/__init__.py +52 -0
  3. braindecode/augmentation/base.py +225 -0
  4. braindecode/augmentation/functional.py +1300 -0
  5. braindecode/augmentation/transforms.py +1356 -0
  6. braindecode/classifier.py +258 -0
  7. braindecode/datasets/__init__.py +44 -0
  8. braindecode/datasets/base.py +823 -0
  9. braindecode/datasets/bbci.py +693 -0
  10. braindecode/datasets/bcicomp.py +193 -0
  11. braindecode/datasets/bids/__init__.py +54 -0
  12. braindecode/datasets/bids/datasets.py +239 -0
  13. braindecode/datasets/bids/format.py +717 -0
  14. braindecode/datasets/bids/hub.py +987 -0
  15. braindecode/datasets/bids/hub_format.py +717 -0
  16. braindecode/datasets/bids/hub_io.py +197 -0
  17. braindecode/datasets/bids/hub_validation.py +114 -0
  18. braindecode/datasets/bids/iterable.py +220 -0
  19. braindecode/datasets/chb_mit.py +163 -0
  20. braindecode/datasets/mne.py +170 -0
  21. braindecode/datasets/moabb.py +219 -0
  22. braindecode/datasets/nmt.py +313 -0
  23. braindecode/datasets/registry.py +120 -0
  24. braindecode/datasets/siena.py +162 -0
  25. braindecode/datasets/sleep_physio_challe_18.py +411 -0
  26. braindecode/datasets/sleep_physionet.py +125 -0
  27. braindecode/datasets/tuh.py +591 -0
  28. braindecode/datasets/utils.py +67 -0
  29. braindecode/datasets/xy.py +96 -0
  30. braindecode/datautil/__init__.py +62 -0
  31. braindecode/datautil/channel_utils.py +114 -0
  32. braindecode/datautil/hub_formats.py +180 -0
  33. braindecode/datautil/serialization.py +359 -0
  34. braindecode/datautil/util.py +154 -0
  35. braindecode/eegneuralnet.py +372 -0
  36. braindecode/functional/__init__.py +22 -0
  37. braindecode/functional/functions.py +251 -0
  38. braindecode/functional/initialization.py +47 -0
  39. braindecode/models/__init__.py +117 -0
  40. braindecode/models/atcnet.py +830 -0
  41. braindecode/models/attentionbasenet.py +727 -0
  42. braindecode/models/attn_sleep.py +549 -0
  43. braindecode/models/base.py +574 -0
  44. braindecode/models/bendr.py +493 -0
  45. braindecode/models/biot.py +537 -0
  46. braindecode/models/brainmodule.py +845 -0
  47. braindecode/models/config.py +233 -0
  48. braindecode/models/contrawr.py +319 -0
  49. braindecode/models/ctnet.py +541 -0
  50. braindecode/models/deep4.py +376 -0
  51. braindecode/models/deepsleepnet.py +417 -0
  52. braindecode/models/eegconformer.py +475 -0
  53. braindecode/models/eeginception_erp.py +379 -0
  54. braindecode/models/eeginception_mi.py +379 -0
  55. braindecode/models/eegitnet.py +302 -0
  56. braindecode/models/eegminer.py +256 -0
  57. braindecode/models/eegnet.py +359 -0
  58. braindecode/models/eegnex.py +354 -0
  59. braindecode/models/eegsimpleconv.py +201 -0
  60. braindecode/models/eegsym.py +917 -0
  61. braindecode/models/eegtcnet.py +337 -0
  62. braindecode/models/fbcnet.py +225 -0
  63. braindecode/models/fblightconvnet.py +315 -0
  64. braindecode/models/fbmsnet.py +338 -0
  65. braindecode/models/hybrid.py +126 -0
  66. braindecode/models/ifnet.py +443 -0
  67. braindecode/models/labram.py +1316 -0
  68. braindecode/models/luna.py +891 -0
  69. braindecode/models/medformer.py +760 -0
  70. braindecode/models/msvtnet.py +377 -0
  71. braindecode/models/patchedtransformer.py +640 -0
  72. braindecode/models/reve.py +843 -0
  73. braindecode/models/sccnet.py +280 -0
  74. braindecode/models/shallow_fbcsp.py +212 -0
  75. braindecode/models/signal_jepa.py +1122 -0
  76. braindecode/models/sinc_shallow.py +339 -0
  77. braindecode/models/sleep_stager_blanco_2020.py +169 -0
  78. braindecode/models/sleep_stager_chambon_2018.py +159 -0
  79. braindecode/models/sparcnet.py +426 -0
  80. braindecode/models/sstdpn.py +869 -0
  81. braindecode/models/summary.csv +47 -0
  82. braindecode/models/syncnet.py +234 -0
  83. braindecode/models/tcn.py +275 -0
  84. braindecode/models/tidnet.py +397 -0
  85. braindecode/models/tsinception.py +295 -0
  86. braindecode/models/usleep.py +439 -0
  87. braindecode/models/util.py +369 -0
  88. braindecode/modules/__init__.py +92 -0
  89. braindecode/modules/activation.py +86 -0
  90. braindecode/modules/attention.py +883 -0
  91. braindecode/modules/blocks.py +160 -0
  92. braindecode/modules/convolution.py +330 -0
  93. braindecode/modules/filter.py +654 -0
  94. braindecode/modules/layers.py +216 -0
  95. braindecode/modules/linear.py +70 -0
  96. braindecode/modules/parametrization.py +38 -0
  97. braindecode/modules/stats.py +87 -0
  98. braindecode/modules/util.py +85 -0
  99. braindecode/modules/wrapper.py +90 -0
  100. braindecode/preprocessing/__init__.py +271 -0
  101. braindecode/preprocessing/eegprep_preprocess.py +1317 -0
  102. braindecode/preprocessing/mne_preprocess.py +240 -0
  103. braindecode/preprocessing/preprocess.py +579 -0
  104. braindecode/preprocessing/util.py +177 -0
  105. braindecode/preprocessing/windowers.py +1037 -0
  106. braindecode/regressor.py +234 -0
  107. braindecode/samplers/__init__.py +18 -0
  108. braindecode/samplers/base.py +399 -0
  109. braindecode/samplers/ssl.py +263 -0
  110. braindecode/training/__init__.py +23 -0
  111. braindecode/training/callbacks.py +23 -0
  112. braindecode/training/losses.py +105 -0
  113. braindecode/training/scoring.py +477 -0
  114. braindecode/util.py +419 -0
  115. braindecode/version.py +1 -0
  116. braindecode/visualization/__init__.py +8 -0
  117. braindecode/visualization/confusion_matrices.py +289 -0
  118. braindecode/visualization/gradients.py +62 -0
  119. braindecode-1.3.0.dev177069446.dist-info/METADATA +230 -0
  120. braindecode-1.3.0.dev177069446.dist-info/RECORD +124 -0
  121. braindecode-1.3.0.dev177069446.dist-info/WHEEL +5 -0
  122. braindecode-1.3.0.dev177069446.dist-info/licenses/LICENSE.txt +31 -0
  123. braindecode-1.3.0.dev177069446.dist-info/licenses/NOTICE.txt +20 -0
  124. braindecode-1.3.0.dev177069446.dist-info/top_level.txt +1 -0
braindecode/models/msvtnet.py
@@ -0,0 +1,377 @@
+ # Authors: Tao Yang <sheeptao@outlook.com>
+ #          Bruno Aristimunha <b.aristimunha@gmail.com> (braindecode adaptation)
+ #
+ from typing import Type, Union
+
+ import torch
+ import torch.nn as nn
+ from einops.layers.torch import Rearrange
+
+ from braindecode.models.base import EEGModuleMixin
+
+
+ class MSVTNet(EEGModuleMixin, nn.Module):
+     r"""MSVTNet model from Liu et al. (2024) [msvt2024]_.
+
+     :bdg-success:`Convolution` :bdg-info:`Attention/Transformer`
+
+     This model implements a multi-scale convolutional transformer network
+     for EEG signal classification, as described in [msvt2024]_.
+
+     .. figure:: https://raw.githubusercontent.com/SheepTAO/MSVTNet/refs/heads/main/MSVTNet_Arch.png
+        :align: center
+        :alt: MSVTNet Architecture
+
+     Parameters
+     ----------
+     n_filters_list : tuple[int, ...], optional
+         Number of filters in each TSConv branch, by default (9, 9, 9, 9).
+     conv1_kernels_size : tuple[int, ...], optional
+         Kernel size of the first convolution in each TSConv branch,
+         by default (15, 31, 63, 125).
+     conv2_kernel_size : int, optional
+         Kernel size of the second convolution in TSConv blocks, by default 15.
+     depth_multiplier : int, optional
+         Depth multiplier for the depthwise convolution, by default 2.
+     pool1_size : int, optional
+         Pooling size of the first pooling layer in TSConv blocks, by default 8.
+     pool2_size : int, optional
+         Pooling size of the second pooling layer in TSConv blocks, by default 7.
+     drop_prob : float, optional
+         Dropout probability for the convolutional layers, by default 0.3.
+     num_heads : int, optional
+         Number of attention heads in the transformer encoder, by default 8.
+     ffn_expansion_factor : float, optional
+         Ratio used to compute the feedforward dimension in the transformer,
+         by default 1.
+     att_drop_prob : float, optional
+         Dropout probability inside the transformer, by default 0.5.
+     num_layers : int, optional
+         Number of transformer encoder layers, by default 2.
+     activation : Type[nn.Module], optional
+         Activation function class to use, by default nn.ELU.
+     return_features : bool, optional
+         Whether to return the predictions of the branch classifiers instead
+         of the final prediction, by default False.
+
+     Notes
+     -----
+     This implementation is not guaranteed to be correct and has not been
+     checked by the original authors; it is a reimplementation based on the
+     original code [msvt2024code]_.
+
+     References
+     ----------
+     .. [msvt2024] Liu, K., et al. (2024). MSVTNet: Multi-Scale Vision
+        Transformer Neural Network for EEG-Based Motor Imagery Decoding.
+        IEEE Journal of Biomedical and Health Informatics.
+     .. [msvt2024code] Liu, K., et al. (2024). MSVTNet: Multi-Scale Vision
+        Transformer Neural Network for EEG-Based Motor Imagery Decoding.
+        Source code: https://github.com/SheepTAO/MSVTNet
+     """
+
+     def __init__(
+         self,
+         # braindecode parameters
+         n_chans=None,
+         n_outputs=None,
+         n_times=None,
+         input_window_seconds=None,
+         sfreq=None,
+         chs_info=None,
+         # Model's parameters
+         n_filters_list: tuple[int, ...] = (9, 9, 9, 9),
+         conv1_kernels_size: tuple[int, ...] = (15, 31, 63, 125),
+         conv2_kernel_size: int = 15,
+         depth_multiplier: int = 2,
+         pool1_size: int = 8,
+         pool2_size: int = 7,
+         drop_prob: float = 0.3,
+         num_heads: int = 8,
+         ffn_expansion_factor: float = 1,
+         att_drop_prob: float = 0.5,
+         num_layers: int = 2,
+         activation: Type[nn.Module] = nn.ELU,
+         return_features: bool = False,
+     ):
+         super().__init__(
+             n_outputs=n_outputs,
+             n_chans=n_chans,
+             chs_info=chs_info,
+             n_times=n_times,
+             input_window_seconds=input_window_seconds,
+             sfreq=sfreq,
+         )
+         del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq
+
+         self.return_features = return_features
+         assert len(n_filters_list) == len(conv1_kernels_size), (
+             "The lengths of n_filters_list and conv1_kernels_size should be equal."
+         )
+
+         self.ensure_dim = Rearrange("batch chans time -> batch 1 chans time")
+         self.mstsconv = nn.ModuleList(
+             [
+                 nn.Sequential(
+                     _TSConv(
+                         self.n_chans,
+                         n_filters_list[b],
+                         conv1_kernels_size[b],
+                         conv2_kernel_size,
+                         depth_multiplier,
+                         pool1_size,
+                         pool2_size,
+                         drop_prob,
+                         activation,
+                     ),
+                     Rearrange("batch channels 1 time -> batch time channels"),
+                 )
+                 for b in range(len(n_filters_list))
+             ]
+         )
+         branch_linear_in = self._forward_flatten(cat=False)
+         self.branch_head = nn.ModuleList(
+             [
+                 _DenseLayers(branch_linear_in[b].shape[1], self.n_outputs)
+                 for b in range(len(n_filters_list))
+             ]
+         )
+
+         seq_len, d_model = self._forward_mstsconv().shape[1:3]  # type: ignore
+         self.transformer = _Transformer(
+             seq_len,
+             d_model,
+             num_heads,
+             ffn_expansion_factor,
+             att_drop_prob,
+             num_layers,
+         )
+
+         linear_in = self._forward_flatten().shape[1]  # type: ignore
+         self.flatten_layer = nn.Flatten()
+         self.final_layer = nn.Linear(linear_in, self.n_outputs)
+
+     def _forward_mstsconv(
+         self, cat: bool = True
+     ) -> Union[torch.Tensor, list[torch.Tensor]]:
+         x = torch.randn(1, 1, self.n_chans, self.n_times)
+         x = [tsconv(x) for tsconv in self.mstsconv]
+         if cat:
+             x = torch.cat(x, dim=2)
+         return x
+
+     def _forward_flatten(
+         self, cat: bool = True
+     ) -> Union[torch.Tensor, list[torch.Tensor]]:
+         x = self._forward_mstsconv(cat)
+         if cat:
+             x = self.transformer(x)
+             x = x.flatten(start_dim=1, end_dim=-1)
+         else:
+             x = [_.flatten(start_dim=1, end_dim=-1) for _ in x]
+         return x
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         # x with shape: (batch, n_chans, n_times)
+         x = self.ensure_dim(x)
+         # x with shape: (batch, 1, n_chans, n_times)
+         x_list = [tsconv(x) for tsconv in self.mstsconv]
+         # x_list holds one tensor per branch, each of shape (batch, seq_len, branch_embed_dim)
+         branch_preds = [
+             branch(x_list[idx]) for idx, branch in enumerate(self.branch_head)
+         ]
+         # branch_preds holds one tensor per branch, each of shape (batch, n_outputs)
+         x = torch.stack(x_list, dim=2)
+         x = x.view(x.size(0), x.size(1), -1)
+         # x shape after concatenation: (batch, seq_len, total_embed_dim)
+         x = self.transformer(x)
+         # x shape after transformer: (batch, total_embed_dim)
+
+         x = self.final_layer(x)
+         # x shape after the final layer: (batch, n_outputs)
+         if self.return_features:
+             # stacked branch predictions, shape (n_branches, batch, n_outputs)
+             return torch.stack(branch_preds)
+         return x
+
+
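A minimal usage sketch of the model above, assuming `MSVTNet` is re-exported from `braindecode.models` (as suggested by `braindecode/models/__init__.py` in the file list); the channel, class, and window sizes are arbitrary:

import torch
from braindecode.models import MSVTNet

model = MSVTNet(n_chans=22, n_outputs=4, n_times=1000)
x = torch.randn(8, 22, 1000)  # (batch, n_chans, n_times)
logits = model(x)             # (8, 4): one score per output class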
+ class _TSConv(nn.Sequential):
+     r"""
+     Temporal-spatial convolution (TSConv) block.
+
+     The architecture consists of:
+
+     - **Temporal Convolution**
+     - **Batch Normalization**
+     - **Depthwise Spatial Convolution**
+     - **Batch Normalization**
+     - **Activation Function**
+     - **First Pooling Layer**
+     - **Dropout**
+     - **Depthwise Temporal Convolution**
+     - **Batch Normalization**
+     - **Activation Function**
+     - **Second Pooling Layer**
+     - **Dropout**
+
+     Parameters
+     ----------
+     n_channels : int
+         Number of input channels (EEG channels).
+     n_filters : int
+         Number of filters for the convolution layers.
+     conv1_kernel_size : int
+         Kernel size for the first convolution layer.
+     conv2_kernel_size : int
+         Kernel size for the second convolution layer.
+     depth_multiplier : int
+         Depth multiplier for the depthwise convolution.
+     pool1_size : int
+         Kernel size for the first pooling layer.
+     pool2_size : int
+         Kernel size for the second pooling layer.
+     drop_prob : float
+         Dropout probability.
+     activation : Type[nn.Module], optional
+         Activation function class to use, by default nn.ELU.
+     """
+
+     def __init__(
+         self,
+         n_channels: int,
+         n_filters: int,
+         conv1_kernel_size: int,
+         conv2_kernel_size: int,
+         depth_multiplier: int,
+         pool1_size: int,
+         pool2_size: int,
+         drop_prob: float,
+         activation: Type[nn.Module] = nn.ELU,
+     ):
+         super().__init__(
+             nn.Conv2d(
+                 in_channels=1,
+                 out_channels=n_filters,
+                 kernel_size=(1, conv1_kernel_size),
+                 padding="same",
+                 bias=False,
+             ),
+             nn.BatchNorm2d(n_filters),
+             nn.Conv2d(
+                 in_channels=n_filters,
+                 out_channels=n_filters * depth_multiplier,
+                 kernel_size=(n_channels, 1),
+                 groups=n_filters,
+                 bias=False,
+             ),
+             nn.BatchNorm2d(n_filters * depth_multiplier),
+             activation(),
+             nn.AvgPool2d(kernel_size=(1, pool1_size)),
+             nn.Dropout(drop_prob),
+             nn.Conv2d(
+                 in_channels=n_filters * depth_multiplier,
+                 out_channels=n_filters * depth_multiplier,
+                 kernel_size=(1, conv2_kernel_size),
+                 padding="same",
+                 groups=n_filters * depth_multiplier,
+                 bias=False,
+             ),
+             nn.BatchNorm2d(n_filters * depth_multiplier),
+             activation(),
+             nn.AvgPool2d(kernel_size=(1, pool2_size)),
+             nn.Dropout(drop_prob),
+         )
+
+
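Each TSConv branch collapses the EEG-channel axis and downsamples time by `pool1_size * pool2_size`, while the embedding dimension becomes `n_filters * depth_multiplier`. A quick shape check, assuming the defaults and the same arbitrary sizes as the sketch above:

import torch

block = _TSConv(
    n_channels=22, n_filters=9, conv1_kernel_size=15, conv2_kernel_size=15,
    depth_multiplier=2, pool1_size=8, pool2_size=7, drop_prob=0.3,
)
out = block(torch.randn(8, 1, 22, 1000))
# 9 * 2 = 18 feature maps and 1000 // 8 // 7 = 17 time steps,
# so out.shape == (8, 18, 1, 17) before the Rearrange to (8, 17, 18)
print(out.shape)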
+ class _PositionalEncoding(nn.Module):
+     r"""
+     Positional encoding module that adds learnable positional embeddings.
+
+     Parameters
+     ----------
+     seq_length : int
+         Sequence length.
+     d_model : int
+         Dimensionality of the model.
+     """
+
+     def __init__(self, seq_length: int, d_model: int) -> None:
+         super().__init__()
+         self.seq_length = seq_length
+         self.d_model = d_model
+         self.pe = nn.Parameter(torch.zeros(1, seq_length, d_model))
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         x = x + self.pe
+         return x
+
+
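Note this is a learned positional encoding (a zero-initialized parameter), not a fixed sinusoidal table; the parameter has a leading batch dimension of 1 and broadcasts over the batch. A two-line check with the sizes produced by the defaults above (seq_len + 1 = 18 after the class token, d_model = 4 * 18 = 72):

import torch

pe = _PositionalEncoding(seq_length=18, d_model=72)
y = pe(torch.randn(8, 18, 72))  # the (1, 18, 72) parameter broadcasts over batch
assert y.shape == (8, 18, 72)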
+ class _Transformer(nn.Module):
+     r"""
+     Transformer encoder module with learnable class token and positional encoding.
+
+     Parameters
+     ----------
+     seq_length : int
+         Sequence length of the input.
+     d_model : int
+         Dimensionality of the model.
+     num_heads : int
+         Number of heads in the multihead attention.
+     ffn_expansion_factor : float
+         Ratio used to compute the dimension of the feedforward network.
+     drop_prob : float, optional
+         Dropout probability, by default 0.5.
+     num_layers : int, optional
+         Number of transformer encoder layers, by default 4.
+     """
+
+     def __init__(
+         self,
+         seq_length: int,
+         d_model: int,
+         num_heads: int,
+         ffn_expansion_factor: float,
+         drop_prob: float = 0.5,
+         num_layers: int = 4,
+     ) -> None:
+         super().__init__()
+         self.cls_embedding = nn.Parameter(torch.zeros(1, 1, d_model))
+         self.pos_embedding = _PositionalEncoding(seq_length + 1, d_model)
+
+         dim_ff = int(d_model * ffn_expansion_factor)
+         self.dropout = nn.Dropout(drop_prob)
+         self.trans = nn.TransformerEncoder(
+             nn.TransformerEncoderLayer(
+                 d_model,
+                 num_heads,
+                 dim_ff,
+                 drop_prob,
+                 batch_first=True,
+                 norm_first=True,
+             ),
+             num_layers,
+             norm=nn.LayerNorm(d_model),
+         )
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         batch_size = x.shape[0]
+         x = torch.cat((self.cls_embedding.expand(batch_size, -1, -1), x), dim=1)
+         x = self.pos_embedding(x)
+         x = self.dropout(x)
+         return self.trans(x)[:, 0]
+
+
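The `[:, 0]` readout keeps only the prepended class token, so `_Transformer` maps a whole sequence to a single `d_model`-sized vector per example. A sketch with the default-derived sizes used above (seq_length = 17, d_model = 72):

import torch

trans = _Transformer(seq_length=17, d_model=72, num_heads=8, ffn_expansion_factor=1)
out = trans(torch.randn(8, 17, 72))
assert out.shape == (8, 72)  # one class-token vector per example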
+ class _DenseLayers(nn.Sequential):
+     r"""
+     Branch classification head: a flatten followed by a linear layer.
+
+     Parameters
+     ----------
+     linear_in : int
+         Input dimension to the linear layer.
+     n_classes : int
+         Number of output classes.
+     """
+
+     def __init__(self, linear_in: int, n_classes: int):
+         super().__init__(
+             nn.Flatten(),
+             nn.Linear(linear_in, n_classes),
+         )
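In the original code [msvt2024code]_, these branch heads serve as auxiliary classifiers: training combines a cross-entropy loss on the final prediction with weighted cross-entropy losses on each branch prediction. A rough sketch of such a joint objective against this module's interface, with an arbitrary `aux_weight` and a second forward pass toggled via `return_features` to fetch the stacked branch predictions:

import torch.nn.functional as F

def joint_loss(model, x, y, aux_weight=0.2):
    logits = model(x)                 # final prediction: (batch, n_outputs)
    model.return_features = True
    branch_logits = model(x)          # stacked branch preds: (n_branches, batch, n_outputs)
    model.return_features = False
    loss = F.cross_entropy(logits, y)
    for bl in branch_logits:          # iterate over the branch dimension
        loss = loss + aux_weight * F.cross_entropy(bl, y)
    return loss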