braindecode 0.8__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of braindecode might be problematic.

Files changed (102)
  1. braindecode/__init__.py +1 -2
  2. braindecode/augmentation/__init__.py +50 -0
  3. braindecode/augmentation/base.py +222 -0
  4. braindecode/augmentation/functional.py +1096 -0
  5. braindecode/augmentation/transforms.py +1274 -0
  6. braindecode/classifier.py +26 -24
  7. braindecode/datasets/__init__.py +34 -0
  8. braindecode/datasets/base.py +840 -0
  9. braindecode/datasets/bbci.py +694 -0
  10. braindecode/datasets/bcicomp.py +194 -0
  11. braindecode/datasets/bids.py +245 -0
  12. braindecode/datasets/mne.py +172 -0
  13. braindecode/datasets/moabb.py +209 -0
  14. braindecode/datasets/nmt.py +311 -0
  15. braindecode/datasets/sleep_physio_challe_18.py +412 -0
  16. braindecode/datasets/sleep_physionet.py +125 -0
  17. braindecode/datasets/tuh.py +588 -0
  18. braindecode/datasets/xy.py +95 -0
  19. braindecode/datautil/__init__.py +49 -0
  20. braindecode/datautil/serialization.py +342 -0
  21. braindecode/datautil/util.py +41 -0
  22. braindecode/eegneuralnet.py +63 -47
  23. braindecode/functional/__init__.py +10 -0
  24. braindecode/functional/functions.py +251 -0
  25. braindecode/functional/initialization.py +47 -0
  26. braindecode/models/__init__.py +52 -0
  27. braindecode/models/atcnet.py +652 -0
  28. braindecode/models/attentionbasenet.py +550 -0
  29. braindecode/models/base.py +296 -0
  30. braindecode/models/biot.py +483 -0
  31. braindecode/models/contrawr.py +296 -0
  32. braindecode/models/ctnet.py +450 -0
  33. braindecode/models/deep4.py +322 -0
  34. braindecode/models/deepsleepnet.py +295 -0
  35. braindecode/models/eegconformer.py +372 -0
  36. braindecode/models/eeginception_erp.py +304 -0
  37. braindecode/models/eeginception_mi.py +371 -0
  38. braindecode/models/eegitnet.py +301 -0
  39. braindecode/models/eegminer.py +255 -0
  40. braindecode/models/eegnet.py +473 -0
  41. braindecode/models/eegnex.py +247 -0
  42. braindecode/models/eegresnet.py +362 -0
  43. braindecode/models/eegsimpleconv.py +199 -0
  44. braindecode/models/eegtcnet.py +335 -0
  45. braindecode/models/fbcnet.py +221 -0
  46. braindecode/models/fblightconvnet.py +313 -0
  47. braindecode/models/fbmsnet.py +325 -0
  48. braindecode/models/hybrid.py +126 -0
  49. braindecode/models/ifnet.py +441 -0
  50. braindecode/models/labram.py +1166 -0
  51. braindecode/models/msvtnet.py +375 -0
  52. braindecode/models/sccnet.py +182 -0
  53. braindecode/models/shallow_fbcsp.py +208 -0
  54. braindecode/models/signal_jepa.py +1012 -0
  55. braindecode/models/sinc_shallow.py +337 -0
  56. braindecode/models/sleep_stager_blanco_2020.py +167 -0
  57. braindecode/models/sleep_stager_chambon_2018.py +157 -0
  58. braindecode/models/sleep_stager_eldele_2021.py +536 -0
  59. braindecode/models/sparcnet.py +378 -0
  60. braindecode/models/summary.csv +41 -0
  61. braindecode/models/syncnet.py +232 -0
  62. braindecode/models/tcn.py +273 -0
  63. braindecode/models/tidnet.py +395 -0
  64. braindecode/models/tsinception.py +258 -0
  65. braindecode/models/usleep.py +340 -0
  66. braindecode/models/util.py +133 -0
  67. braindecode/modules/__init__.py +38 -0
  68. braindecode/modules/activation.py +60 -0
  69. braindecode/modules/attention.py +757 -0
  70. braindecode/modules/blocks.py +108 -0
  71. braindecode/modules/convolution.py +274 -0
  72. braindecode/modules/filter.py +632 -0
  73. braindecode/modules/layers.py +133 -0
  74. braindecode/modules/linear.py +50 -0
  75. braindecode/modules/parametrization.py +38 -0
  76. braindecode/modules/stats.py +77 -0
  77. braindecode/modules/util.py +77 -0
  78. braindecode/modules/wrapper.py +75 -0
  79. braindecode/preprocessing/__init__.py +37 -0
  80. braindecode/preprocessing/mne_preprocess.py +77 -0
  81. braindecode/preprocessing/preprocess.py +478 -0
  82. braindecode/preprocessing/windowers.py +1031 -0
  83. braindecode/regressor.py +23 -12
  84. braindecode/samplers/__init__.py +18 -0
  85. braindecode/samplers/base.py +401 -0
  86. braindecode/samplers/ssl.py +263 -0
  87. braindecode/training/__init__.py +23 -0
  88. braindecode/training/callbacks.py +23 -0
  89. braindecode/training/losses.py +105 -0
  90. braindecode/training/scoring.py +483 -0
  91. braindecode/util.py +55 -59
  92. braindecode/version.py +1 -1
  93. braindecode/visualization/__init__.py +8 -0
  94. braindecode/visualization/confusion_matrices.py +289 -0
  95. braindecode/visualization/gradients.py +57 -0
  96. {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info}/METADATA +39 -55
  97. braindecode-1.0.0.dist-info/RECORD +101 -0
  98. {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info}/WHEEL +1 -1
  99. {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info/licenses}/LICENSE.txt +1 -1
  100. braindecode-1.0.0.dist-info/licenses/NOTICE.txt +20 -0
  101. braindecode-0.8.dist-info/RECORD +0 -11
  102. {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info}/top_level.txt +0 -0
braindecode/models/msvtnet.py
@@ -0,0 +1,375 @@
+ # Authors: Tao Yang <sheeptao@outlook.com>
+ #          Bruno Aristimunha <b.aristimunha@gmail.com> (braindecode adaptation)
+ #
+ from typing import Dict, Optional, Type, Union
+
+ import torch
+ import torch.nn as nn
+ from einops.layers.torch import Rearrange
+
+ from braindecode.models.base import EEGModuleMixin
+
+
+ class MSVTNet(EEGModuleMixin, nn.Module):
+     """MSVTNet model from Liu K et al. (2024) [msvt2024]_.
+
+     This model implements a multi-scale convolutional transformer network
+     for EEG signal classification, as described in [msvt2024]_.
+
+     .. figure:: https://raw.githubusercontent.com/SheepTAO/MSVTNet/refs/heads/main/MSVTNet_Arch.png
+        :align: center
+        :alt: MSVTNet Architecture
+
+     Parameters
+     ----------
+     n_filters_list : list[int], optional
+         List of filter numbers for each TSConv block, by default (9, 9, 9, 9).
+     conv1_kernels_size : list[int], optional
+         List of kernel sizes for the first convolution in each TSConv block,
+         by default (15, 31, 63, 125).
+     conv2_kernel_size : int, optional
+         Kernel size for the second convolution in TSConv blocks, by default 15.
+     depth_multiplier : int, optional
+         Depth multiplier for the depthwise convolution, by default 2.
+     pool1_size : int, optional
+         Pooling size for the first pooling layer in TSConv blocks, by default 8.
+     pool2_size : int, optional
+         Pooling size for the second pooling layer in TSConv blocks, by default 7.
+     drop_prob : float, optional
+         Dropout probability for the convolutional layers, by default 0.3.
+     num_heads : int, optional
+         Number of attention heads in the transformer encoder, by default 8.
+     feedforward_ratio : float, optional
+         Ratio used to compute the feedforward dimension in the transformer, by default 1.
+     drop_prob_trans : float, optional
+         Dropout probability for the transformer, by default 0.5.
+     num_layers : int, optional
+         Number of transformer encoder layers, by default 2.
+     activation : Type[nn.Module], optional
+         Activation function class to use, by default nn.ELU.
+     return_features : bool, optional
+         Whether to return the predictions from the branch classifiers, by default False.
+
+     Notes
+     -----
+     This implementation is not guaranteed to be correct and has not been
+     checked by the original authors; it is a reimplementation based on the original code [msvt2024code]_.
+
+     References
+     ----------
+     .. [msvt2024] Liu, K., et al. (2024). MSVTNet: Multi-Scale Vision
+        Transformer Neural Network for EEG-Based Motor Imagery Decoding.
+        IEEE Journal of Biomedical and Health Informatics.
+     .. [msvt2024code] Liu, K., et al. (2024). MSVTNet: Multi-Scale Vision
+        Transformer Neural Network for EEG-Based Motor Imagery Decoding.
+        Source Code: https://github.com/SheepTAO/MSVTNet
+     """
+
+     def __init__(
+         self,
+         # braindecode parameters
+         n_chans=None,
+         n_outputs=None,
+         n_times=None,
+         input_window_seconds=None,
+         sfreq=None,
+         chs_info=None,
+         # Model's parameters
+         n_filters_list: tuple[int, ...] = (9, 9, 9, 9),
+         conv1_kernels_size: tuple[int, ...] = (15, 31, 63, 125),
+         conv2_kernel_size: int = 15,
+         depth_multiplier: int = 2,
+         pool1_size: int = 8,
+         pool2_size: int = 7,
+         drop_prob: float = 0.3,
+         num_heads: int = 8,
+         feedforward_ratio: float = 1,
+         drop_prob_trans: float = 0.5,
+         num_layers: int = 2,
+         activation: Type[nn.Module] = nn.ELU,
+         return_features: bool = False,
+     ):
+         super().__init__(
+             n_outputs=n_outputs,
+             n_chans=n_chans,
+             chs_info=chs_info,
+             n_times=n_times,
+             input_window_seconds=input_window_seconds,
+             sfreq=sfreq,
+         )
+         del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq
+
+         self.return_features = return_features
+         assert len(n_filters_list) == len(conv1_kernels_size), (
+             "The length of n_filters_list and conv1_kernel_sizes should be equal."
+         )
+
+         self.ensure_dim = Rearrange("batch chans time -> batch 1 chans time")
+         self.mstsconv = nn.ModuleList(
+             [
+                 nn.Sequential(
+                     _TSConv(
+                         self.n_chans,
+                         n_filters_list[b],
+                         conv1_kernels_size[b],
+                         conv2_kernel_size,
+                         depth_multiplier,
+                         pool1_size,
+                         pool2_size,
+                         drop_prob,
+                         activation,
+                     ),
+                     Rearrange("batch channels 1 time -> batch time channels"),
+                 )
+                 for b in range(len(n_filters_list))
+             ]
+         )
+         branch_linear_in = self._forward_flatten(cat=False)
+         self.branch_head = nn.ModuleList(
+             [
+                 _DenseLayers(branch_linear_in[b].shape[1], self.n_outputs)
+                 for b in range(len(n_filters_list))
+             ]
+         )
+
+         seq_len, d_model = self._forward_mstsconv().shape[1:3]  # type: ignore
+         self.transformer = _Transformer(
+             seq_len,
+             d_model,
+             num_heads,
+             feedforward_ratio,
+             drop_prob_trans,
+             num_layers,
+         )
+
+         linear_in = self._forward_flatten().shape[1]  # type: ignore
+         self.flatten_layer = nn.Flatten()
+         self.final_layer = nn.Linear(linear_in, self.n_outputs)
+
+     def _forward_mstsconv(
+         self, cat: bool = True
+     ) -> Union[torch.Tensor, list[torch.Tensor]]:
+         x = torch.randn(1, 1, self.n_chans, self.n_times)
+         x = [tsconv(x) for tsconv in self.mstsconv]
+         if cat:
+             x = torch.cat(x, dim=2)
+         return x
+
+     def _forward_flatten(
+         self, cat: bool = True
+     ) -> Union[torch.Tensor, list[torch.Tensor]]:
+         x = self._forward_mstsconv(cat)
+         if cat:
+             x = self.transformer(x)
+             x = x.flatten(start_dim=1, end_dim=-1)
+         else:
+             x = [_.flatten(start_dim=1, end_dim=-1) for _ in x]
+         return x
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         # x with shape: (batch, n_chans, n_times)
+         x = self.ensure_dim(x)
+         # x with shape: (batch, 1, n_chans, n_times)
+         x_list = [tsconv(x) for tsconv in self.mstsconv]
+         # x_list contains 4 tensors, each of shape: [batch_size, seq_len, embed_dim]
+         branch_preds = [
+             branch(x_list[idx]) for idx, branch in enumerate(self.branch_head)
+         ]
+         # branch_preds contains 4 tensors, each of shape: [batch_size, num_classes]
+         x = torch.stack(x_list, dim=2)
+         x = x.view(x.size(0), x.size(1), -1)
+         # x shape after concatenation: [batch_size, seq_len, total_embed_dim]
+         x = self.transformer(x)
+         # x shape after transformer: [batch_size, embed_dim]
+
+         x = self.final_layer(x)
+         if self.return_features:
+             # x shape after final layer: [batch_size, num_classes]
+             # branch_preds shape: [batch_size, num_classes]
+             return torch.stack(branch_preds)
+         return x
+
+
+ class _TSConv(nn.Sequential):
+     """
+     Time-Distributed Separable Convolution block.
+
+     The architecture consists of:
+     - **Temporal Convolution**
+     - **Batch Normalization**
+     - **Depthwise Spatial Convolution**
+     - **Batch Normalization**
+     - **Activation Function**
+     - **First Pooling Layer**
+     - **Dropout**
+     - **Depthwise Temporal Convolution**
+     - **Batch Normalization**
+     - **Activation Function**
+     - **Second Pooling Layer**
+     - **Dropout**
+
+     Parameters
+     ----------
+     n_channels : int
+         Number of input channels (EEG channels).
+     n_filters : int
+         Number of filters for the convolution layers.
+     conv1_kernel_size : int
+         Kernel size for the first convolution layer.
+     conv2_kernel_size : int
+         Kernel size for the second convolution layer.
+     depth_multiplier : int
+         Depth multiplier for depthwise convolution.
+     pool1_size : int
+         Kernel size for the first pooling layer.
+     pool2_size : int
+         Kernel size for the second pooling layer.
+     drop_prob : float
+         Dropout probability.
+     activation : Type[nn.Module], optional
+         Activation function class to use, by default nn.ELU.
+     """
+
+     def __init__(
+         self,
+         n_channels: int,
+         n_filters: int,
+         conv1_kernel_size: int,
+         conv2_kernel_size: int,
+         depth_multiplier: int,
+         pool1_size: int,
+         pool2_size: int,
+         drop_prob: float,
+         activation: Type[nn.Module] = nn.ELU,
+     ):
+         super().__init__(
+             nn.Conv2d(
+                 in_channels=1,
+                 out_channels=n_filters,
+                 kernel_size=(1, conv1_kernel_size),
+                 padding="same",
+                 bias=False,
+             ),
+             nn.BatchNorm2d(n_filters),
+             nn.Conv2d(
+                 in_channels=n_filters,
+                 out_channels=n_filters * depth_multiplier,
+                 kernel_size=(n_channels, 1),
+                 groups=n_filters,
+                 bias=False,
+             ),
+             nn.BatchNorm2d(n_filters * depth_multiplier),
+             activation(),
+             nn.AvgPool2d(kernel_size=(1, pool1_size)),
+             nn.Dropout(drop_prob),
+             nn.Conv2d(
+                 in_channels=n_filters * depth_multiplier,
+                 out_channels=n_filters * depth_multiplier,
+                 kernel_size=(1, conv2_kernel_size),
+                 padding="same",
+                 groups=n_filters * depth_multiplier,
+                 bias=False,
+             ),
+             nn.BatchNorm2d(n_filters * depth_multiplier),
+             activation(),
+             nn.AvgPool2d(kernel_size=(1, pool2_size)),
+             nn.Dropout(drop_prob),
+         )
+
+
+ class _PositionalEncoding(nn.Module):
+     """
+     Positional encoding module that adds learnable positional embeddings.
+
+     Parameters
+     ----------
+     seq_length : int
+         Sequence length.
+     d_model : int
+         Dimensionality of the model.
+     """
+
+     def __init__(self, seq_length: int, d_model: int) -> None:
+         super().__init__()
+         self.seq_length = seq_length
+         self.d_model = d_model
+         self.pe = nn.Parameter(torch.zeros(1, seq_length, d_model))
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         x = x + self.pe
+         return x
+
+
+ class _Transformer(nn.Module):
+     """
+     Transformer encoder module with learnable class token and positional encoding.
+
+     Parameters
+     ----------
+     seq_length : int
+         Sequence length of the input.
+     d_model : int
+         Dimensionality of the model.
+     num_heads : int
+         Number of heads in the multihead attention.
+     feedforward_ratio : float
+         Ratio to compute the dimension of the feedforward network.
+     drop_prob : float, optional
+         Dropout probability, by default 0.5.
+     num_layers : int, optional
+         Number of transformer encoder layers, by default 4.
+     """
+
+     def __init__(
+         self,
+         seq_length: int,
+         d_model: int,
+         num_heads: int,
+         feedforward_ratio: float,
+         drop_prob: float = 0.5,
+         num_layers: int = 4,
+     ) -> None:
+         super().__init__()
+         self.cls_embedding = nn.Parameter(torch.zeros(1, 1, d_model))
+         self.pos_embedding = _PositionalEncoding(seq_length + 1, d_model)
+
+         dim_ff = int(d_model * feedforward_ratio)
+         self.dropout = nn.Dropout(drop_prob)
+         self.trans = nn.TransformerEncoder(
+             nn.TransformerEncoderLayer(
+                 d_model,
+                 num_heads,
+                 dim_ff,
+                 drop_prob,
+                 batch_first=True,
+                 norm_first=True,
+             ),
+             num_layers,
+             norm=nn.LayerNorm(d_model),
+         )
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         batch_size = x.shape[0]
+         x = torch.cat((self.cls_embedding.expand(batch_size, -1, -1), x), dim=1)
+         x = self.pos_embedding(x)
+         x = self.dropout(x)
+         return self.trans(x)[:, 0]
+
+
+ class _DenseLayers(nn.Sequential):
+     """
+     Final classification layers.
+
+     Parameters
+     ----------
+     linear_in : int
+         Input dimension to the linear layer.
+     n_classes : int
+         Number of output classes.
+     """
+
+     def __init__(self, linear_in: int, n_classes: int):
+         super().__init__(
+             nn.Flatten(),
+             nn.Linear(linear_in, n_classes),
+         )
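
For orientation, here is a minimal usage sketch of the MSVTNet class added above. It is not part of the diff: the constructor arguments follow the signature shown in this hunk, the import path assumes the model is re-exported from braindecode.models (as the updated braindecode/models/__init__.py suggests), and the input sizes are arbitrary illustration values.

    import torch
    from braindecode.models import MSVTNet  # assumed re-export; otherwise braindecode.models.msvtnet.MSVTNet

    # 22-channel EEG windows of 1000 samples, 4 target classes (illustrative values)
    model = MSVTNet(n_chans=22, n_times=1000, n_outputs=4)
    x = torch.randn(8, 22, 1000)  # (batch, n_chans, n_times)
    y = model(x)                  # (batch, n_outputs); with return_features=True the stacked branch predictions are returned instead
    print(y.shape)                # torch.Size([8, 4])
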
braindecode/models/sccnet.py
@@ -0,0 +1,182 @@
+ # Authors: Chun-Shu Wei
+ #          Bruno Aristimunha <b.aristimunha@gmail.com> (braindecode adaptation)
+ #
+ # License: BSD (3-clause)
+
+ import math
+
+ import torch
+ from einops.layers.torch import Rearrange
+ from torch import nn
+
+ from braindecode.models.base import EEGModuleMixin
+ from braindecode.modules import LogActivation
+
+
+ class SCCNet(EEGModuleMixin, nn.Module):
+     """SCCNet from Wei, C. S. (2019) [sccnet]_.
+
+     Spatial component-wise convolutional network (SCCNet) for motor-imagery EEG
+     classification.
+
+     .. figure:: https://dt5vp8kor0orz.cloudfront.net/6e3ec5d729cd51fe8acc5a978db27d02a5df9e05/2-Figure1-1.png
+        :align: center
+        :alt: Spatial component-wise convolutional network
+
+
+     1. **Spatial Component Analysis**: Performs convolutional spatial filtering
+        across all EEG channels to extract spatial components, effectively
+        reducing the channel dimension.
+     2. **Spatio-Temporal Filtering**: Applies convolution across the spatial
+        components and the temporal domain to capture spatio-temporal patterns.
+     3. **Temporal Smoothing (Pooling)**: Uses average pooling over time to smooth the
+        features and reduce the temporal dimension, focusing on longer-term patterns.
+     4. **Classification**: Flattens the features and applies a fully connected
+        layer.
+
+
+     Parameters
+     ----------
+     n_spatial_filters : int, optional
+         Number of spatial filters in the first convolutional layer. Default is 22.
+     n_spatial_filters_smooth : int, optional
+         Number of spatial filters used as filters in the second convolutional
+         layer. Default is 20.
+     drop_prob : float, optional
+         Dropout probability. Default is 0.5.
+     activation : nn.Module, optional
+         Activation function after the second convolutional layer. Default is
+         the logarithm activation.
+
+     Notes
+     -----
+     This implementation is not guaranteed to be correct and has not been
+     checked by the original authors; it was reimplemented from the paper
+     description and from source code that has not been tested [sccnetcode]_.
+
+
+     References
+     ----------
+     .. [sccnet] Wei, C. S., Koike-Akino, T., & Wang, Y. (2019, March). Spatial
+        component-wise convolutional network (SCCNet) for motor-imagery EEG
+        classification. In 2019 9th International IEEE/EMBS Conference on
+        Neural Engineering (NER) (pp. 328-331). IEEE.
+     .. [sccnetcode] Hsieh, C. Y., Chou, J. L., Chang, Y. H., & Wei, C. S.
+        XBrainLab: An Open-Source Software for Explainable Artificial
+        Intelligence-Based EEG Analysis. In NeurIPS 2023 AI for
+        Science Workshop.
+
+     """
+
+     def __init__(
+         self,
+         # Signal related parameters
+         n_chans=None,
+         n_outputs=None,
+         n_times=None,
+         chs_info=None,
+         input_window_seconds=None,
+         sfreq=None,
+         # Model related parameters
+         n_spatial_filters: int = 22,
+         n_spatial_filters_smooth: int = 20,
+         drop_prob: float = 0.5,
+         activation: nn.Module = LogActivation,
+         batch_norm_momentum: float = 0.1,
+     ):
+         super().__init__(
+             n_outputs=n_outputs,
+             n_chans=n_chans,
+             chs_info=chs_info,
+             n_times=n_times,
+             input_window_seconds=input_window_seconds,
+             sfreq=sfreq,
+         )
+         del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq
+         # Parameters
+         self.n_spatial_filters = n_spatial_filters
+         self.n_spatial_filters_smooth = n_spatial_filters_smooth
+         self.drop_prob = drop_prob
+
+         self.samples_100ms = int(math.floor(self.sfreq * 0.1))
+         self.kernel_size_pool = int(self.sfreq * 0.5)
+         # Equivalent to 0.5 seconds
+
+         num_features = self._calc_num_features()
+
+         # Layers
+         self.ensure_dim = Rearrange("batch nchan times -> batch 1 nchan times")
+
+         self.activation = LogActivation() if activation is None else activation()
+
+         self.spatial_conv = nn.Conv2d(
+             in_channels=1,
+             out_channels=self.n_spatial_filters,
+             kernel_size=(self.n_chans, 1),
+         )
+
+         self.spatial_batch_norm = nn.BatchNorm2d(
+             self.n_spatial_filters, momentum=batch_norm_momentum
+         )
+
+         self.permute = Rearrange(
+             "batch filspat nchans time -> batch nchans filspat time"
+         )
+
+         self.spatial_filt_conv = nn.Conv2d(
+             in_channels=1,
+             out_channels=self.n_spatial_filters_smooth,
+             kernel_size=(self.n_spatial_filters, self.samples_100ms),
+             bias=False,
+         )
+         self.batch_norm = nn.BatchNorm2d(
+             self.n_spatial_filters_smooth, momentum=batch_norm_momentum
+         )
+
+         self.dropout = nn.Dropout(self.drop_prob)
+         self.temporal_smoothing = nn.AvgPool2d(
+             kernel_size=(1, int(self.sfreq / 2)),
+             stride=(1, self.samples_100ms),
+         )
+
+         self.final_layer = nn.Linear(num_features, self.n_outputs)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         # Shape: (batch_size, n_chans, n_times)
+         x = self.ensure_dim(x)
+         # Shape: (batch_size, 1, n_chans, n_times)
+         x = self.spatial_conv(x)
+         # Shape: (batch_size, n_filters, 1, n_times)
+         x = self.spatial_batch_norm(x)
+         # Shape: (batch_size, n_filters, 1, n_times)
+         x = self.permute(x)
+         # Shape: (batch_size, 1, n_filters, n_times)
+         x = self.spatial_filt_conv(x)
+         # Shape: (batch_size, n_filters_filt, 1, n_times_reduced)
+         x = self.batch_norm(x)
+         # Shape: (batch_size, n_filters_filt, 1, n_times_reduced)
+         x = torch.pow(x, 2)
+         # Shape: (batch_size, n_filters_filt, 1, n_times_reduced)
+         x = self.dropout(x)
+         # Shape: (batch_size, n_filters_filt, 1, n_times_reduced)
+         x = self.temporal_smoothing(x)
+         # Shape: (batch_size, n_filters_filt, 1, n_times_reduced_avg_pool)
+         x = self.activation(x)
+         # Shape: (batch_size, n_filters_filt, 1, n_times_reduced_avg_pool)
+         x = x.view(x.size(0), -1)
+         # Shape: (batch_size, n_filters_filt*n_times_reduced_avg_pool)
+         x = self.final_layer(x)
+         # Shape: (batch_size, n_outputs)
+         return x
+
+     def _calc_num_features(self) -> int:
+         # Compute the number of features for the final linear layer
+         w_out_conv2 = (
+             self.n_times - self.samples_100ms + 1  # After second conv layer
+         )
+         w_out_pool = (
+             (w_out_conv2 - self.kernel_size_pool) // self.samples_100ms + 1
+             # After pooling layer
+         )
+         num_features = self.n_spatial_filters_smooth * w_out_pool
+         return num_features
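
As a worked example of the feature-size computation in _calc_num_features above, take a 4-second window at 250 Hz (assumed values, not taken from the diff): samples_100ms = floor(250 * 0.1) = 25 and kernel_size_pool = int(250 * 0.5) = 125, so w_out_conv2 = 1000 - 25 + 1 = 976, w_out_pool = (976 - 125) // 25 + 1 = 35, and num_features = 20 * 35 = 700, i.e. the final layer is nn.Linear(700, n_outputs). A minimal sketch checking this end to end (import path assumed, as for the MSVTNet example above):

    import torch
    from braindecode.models import SCCNet  # assumed re-export

    model = SCCNet(n_chans=22, n_outputs=4, n_times=1000, sfreq=250.0)
    x = torch.randn(8, 22, 1000)  # (batch, n_chans, n_times)
    print(model(x).shape)         # torch.Size([8, 4])
    print(model.final_layer)      # Linear(in_features=700, out_features=4, bias=True)
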